diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst
index 9062b3b..5d3c3e5 100644
--- a/docs/source/_rst/_code.rst
+++ b/docs/source/_rst/_code.rst
@@ -1,7 +1,7 @@
 Code Documentation
 ==================
 Welcome to PINA documentation! Here you can find the modules of the package divided in different sections.
-The high-level structure of the package is depicted in our API.
+The high-level structure of the package is depicted in our API.
 
 .. figure:: ../index_files/API_color.png
    :alt: PINA application program interface
@@ -33,7 +33,7 @@ Solvers
 .. toctree::
     :titlesonly:
-    
+
     SolverInterface
     PINNInterface
     PINN
@@ -82,13 +82,14 @@ Layers
     Proper Orthogonal Decomposition
     Periodic Boundary Condition Embedding
     Fourier Feature Embedding
+    Radial Basis Function Interpolation
 
 Adaptive Activation Functions
 -------------------------------
 
 .. toctree::
     :titlesonly:
-    
+
     Adaptive Function Interface
     Adaptive ReLU
     Adaptive Sigmoid
@@ -102,14 +103,14 @@ Adaptive Activation Functions
     Adaptive Softmax
     Adaptive SIREN
     Adaptive Exp
-    
+
 
 Equations and Operators
 -------------------------
 
 .. toctree::
     :titlesonly:
-    
+
     Equations
     Differential Operators
@@ -166,4 +167,4 @@ Metrics and Losses
 
     LossInterface
     LpLoss
-    PowerLoss
\ No newline at end of file
+    PowerLoss
diff --git a/docs/source/_rst/_tutorial.rst b/docs/source/_rst/_tutorial.rst
index 756d42e..4e2d205 100644
--- a/docs/source/_rst/_tutorial.rst
+++ b/docs/source/_rst/_tutorial.rst
@@ -43,4 +43,4 @@ Supervised Learning
     :titlesonly:
 
     Unstructured convolutional autoencoder via continuous convolution
-    POD-NN for reduced order modeling
+    POD-RBF and POD-NN for reduced order modeling
diff --git a/docs/source/_rst/layers/rbf_layer.rst b/docs/source/_rst/layers/rbf_layer.rst
new file mode 100644
index 0000000..8736d1a
--- /dev/null
+++ b/docs/source/_rst/layers/rbf_layer.rst
@@ -0,0 +1,7 @@
+RBFBlock
+========
+.. currentmodule:: pina.model.layers.rbf_layer
+
+.. autoclass:: RBFBlock
+   :members:
+   :show-inheritance:
diff --git a/docs/source/_rst/tutorials/tutorial8/tutorial.rst b/docs/source/_rst/tutorials/tutorial8/tutorial.rst
index b160e09..5e6dca9 100644
--- a/docs/source/_rst/tutorials/tutorial8/tutorial.rst
+++ b/docs/source/_rst/tutorials/tutorial8/tutorial.rst
@@ -1,18 +1,20 @@
-Tutorial: Reduced order model (PODNN) for parametric problems
-===============================================================
+Tutorial: Reduced order model (POD-RBF or POD-NN) for parametric problems
+=========================================================================
 
 The tutorial aims to show how to employ the **PINA** library in order
 to apply a reduced order modeling technique [1]. Such methodologies
 have several similarities with machine learning approaches, since the main
-goal consists of predicting the solution of differential equations
+goal consists in predicting the solution of differential equations
 (typically parametric PDEs) in a real-time fashion.
 
 In particular we are going to use the Proper Orthogonal Decomposition
-with Neural Network (PODNN) [2], which basically perform a dimensional
-reduction using the POD approach, approximating the parametric solution
-manifold (at the reduced space) using a NN. In this example, we use a
-simple multilayer perceptron, but the plenty of different archiutectures
-can be plugged as well.
+with either Radial Basis Function Interpolation (POD-RBF) or Neural
+Network (POD-NN) [2]. Here we basically perform a dimensional reduction
+using the POD approach, and approximate the parametric solution
+manifold (at the reduced space) using an interpolation (RBF) or a
+regression technique (NN). In this example, we use a simple multilayer
+perceptron, but plenty of different architectures can be plugged in as
+well.
 
 References
 ^^^^^^^^^^
 
 1. Rozza G., Stabile G., Ballarin F. (2022). Advanced Reduced Order Methods and Applications in Computational Fluid Dynamics, Society for Industrial and Applied Mathematics.
 2. Hesthaven, J. S., & Ubbiali, S. (2018). Non-intrusive reduced order modeling of nonlinear problems using neural networks. Journal of Computational Physics, 363, 55-78.
 
 Let's start with the necessary imports. It's important to note the
 minimum PINA version to run this tutorial is the ``0.1``.
 
 .. code:: ipython3
 
     %matplotlib inline
-    
+
     import matplotlib.pyplot as plt
     import torch
     import pina
-    
+
     from pina.geometry import CartesianDomain
-    
+
     from pina.problem import ParametricProblem
-    from pina.model.layers import PODBlock
+    from pina.model.layers import PODBlock, RBFBlock
     from pina import Condition, LabelTensor, Trainer
     from pina.model import FeedForward
     from pina.solvers import SupervisedSolver
-    
+
     print(f'We are using PINA version {pina.__version__}')
 
 .. parsed-literal::
 
-    We are using PINA version 0.1
+    We are using PINA version 0.1.1
 
 We exploit the `Smithers `__ library to
@@ -60,26 +62,27 @@ snapshots of the velocity (along :math:`x`, :math:`y`, and the
 magnitude) and pressure fields, and the corresponding parameter values.
 
 To visually check the snapshots, let's plot also the data points and the
-reference solution: this is the expected output of the neural network.
+reference solution: this is the expected output of our model.
 
 .. code:: ipython3
 
     from smithers.dataset import NavierStokesDataset
     dataset = NavierStokesDataset()
-    
+
     fig, axs = plt.subplots(1, 4, figsize=(14, 3))
     for ax, p, u in zip(axs, dataset.params[:4], dataset.snapshots['mag(v)'][:4]):
         ax.tricontourf(dataset.triang, u, levels=16)
         ax.set_title(f'$\mu$ = {p[0]:.2f}')
 
-.. image:: tutorial_files/tutorial_5_1.png
+
+.. image:: tutorial_files/tutorial_5_0.png
 
 The *snapshots* - aka the numerical solutions computed for several
 parameters - and the corresponding parameters are the only data we need
-to train the model, in order to predict for any new test parameter the
-solution. To properly validate the accuracy, we initially split the 500
+to train the model, in order to predict the solution for any new test
+parameter. To properly validate the accuracy, we initially split the 500
 snapshots into the training dataset (90% of the original data) and the
 testing one (the reamining 10%). It must be said that, to plug the
 snapshots into **PINA**, we have to cast them to ``LabelTensor``
@@ -89,10 +92,10 @@ objects.
 
     u = torch.tensor(dataset.snapshots['mag(v)']).float()
     p = torch.tensor(dataset.params).float()
-    
+
     p = LabelTensor(p, labels=['mu'])
     u = LabelTensor(u, labels=[f's{i}' for i in range(u.shape[1])])
-    
+
     ratio_train_test = 0.9
     n = u.shape
     n_train = int(u.shape[0] * ratio_train_test)
@@ -109,17 +112,94 @@ methodology), just defining a simple *input-output* condition.
 
     class SnapshotProblem(ParametricProblem):
         output_variables = [f's{i}' for i in range(u.shape[1])]
         parameter_domain = CartesianDomain({'mu': [0, 100]})
-    
+
         conditions = {
-            'io': Condition(input_points=p, output_points=u)
+            'io': Condition(input_points=p_train, output_points=u_train)
         }
 
-Then, we define the model we want to use: basically we have a MLP
-architecture that takes in input the parameter and return the *modal
-coefficients*, so the reduced dimension representation (the coordinates
-in the POD space). Such latent variable is the projected to the original
-space using the POD modes, which are computed and stored in the
-``PODBlock`` object.
+    poisson_problem = SnapshotProblem()
+
+We can then build a ``PODRBF`` model (using Radial Basis Function
+interpolation for the approximation) and a ``PODNN`` approach (using an MLP
+architecture for the approximation).
+
+POD-RBF reduced order model
+---------------------------
+
+Then, we define the model we want to use, with the POD (``PODBlock``)
+and the RBF (``RBFBlock``) objects.
+
+.. code:: ipython3
+
+    class PODRBF(torch.nn.Module):
+        """
+        Proper orthogonal decomposition with Radial Basis Function interpolation model.
+        """
+
+        def __init__(self, pod_rank, rbf_kernel):
+            """
+            :param int pod_rank: rank of the POD layer.
+            :param str rbf_kernel: radial basis function kernel.
+            """
+            super().__init__()
+
+            self.pod = PODBlock(pod_rank)
+            self.rbf = RBFBlock(kernel=rbf_kernel)
+
+
+        def forward(self, x):
+            """
+            Defines the computation performed at every call.
+
+            :param x: The tensor to apply the forward pass.
+            :type x: torch.Tensor
+            :return: the output computed by the model.
+            :rtype: torch.Tensor
+            """
+            coefficients = self.rbf(x)
+            return self.pod.expand(coefficients)
+
+        def fit(self, p, x):
+            """
+            Call the :meth:`pina.model.layers.PODBlock.fit` method of the
+            :attr:`pina.model.layers.PODBlock` attribute to perform the POD,
+            and the :meth:`pina.model.layers.RBFBlock.fit` method of the
+            :attr:`pina.model.layers.RBFBlock` attribute to fit the interpolation.
+            """
+            self.pod.fit(x)
+            self.rbf.fit(p, self.pod.reduce(x))
+
+We can then fit the model and ask it to predict the required field for
+unseen values of the parameters. Note that this model does not need a
+``Trainer`` since it does not include any neural network or learnable
+parameters.
+
+.. code:: ipython3
+
+    pod_rbf = PODRBF(pod_rank=20, rbf_kernel='thin_plate_spline')
+    pod_rbf.fit(p_train, u_train)
+
+.. code:: ipython3
+
+    u_test_rbf = pod_rbf(p_test)
+    u_train_rbf = pod_rbf(p_train)
+
+    relative_error_train = torch.norm(u_train_rbf - u_train)/torch.norm(u_train)
+    relative_error_test = torch.norm(u_test_rbf - u_test)/torch.norm(u_test)
+
+    print('Error summary for POD-RBF model:')
+    print(f'  Train: {relative_error_train.item():e}')
+    print(f'  Test:  {relative_error_test.item():e}')
+
+
+.. parsed-literal::
+
+    Error summary for POD-RBF model:
+      Train: 1.287801e-03
+      Test:  1.217041e-03
+
+
+POD-NN reduced order model
+--------------------------
@@ -127,13 +207,13 @@
 
 .. code:: ipython3
 
     class PODNN(torch.nn.Module):
         """
         Proper orthogonal decomposition with neural network model.
         """
-    
+
         def __init__(self, pod_rank, layers, func):
             """
-    
+
             """
             super().__init__()
-    
+
             self.pod = PODBlock(pod_rank)
             self.nn = FeedForward(
                 input_dimensions=1,
@@ -141,12 +221,12 @@ space using the POD modes, which are computed and stored in the
                 layers=layers,
                 func=func
             )
-    
-    
+
+
         def forward(self, x):
             """
             Defines the computation performed at every call.
-    
+
             :param x: The tensor to apply the forward pass.
             :type x: torch.Tensor
             :return: the output computed by the model.
@@ -154,7 +234,7 @@ space using the POD modes, which are computed and stored in the
             """
             coefficents = self.nn(x)
             return self.pod.expand(coefficents)
-    
+
         def fit_pod(self, x):
             """
             Just call the :meth:`pina.model.layers.PODBlock.fit` method of the
@@ -164,29 +244,27 @@ space using the POD modes, which are computed and stored in the
 
 We highlight that the POD modes are directly computed by means of the
 singular value decomposition (computed over the input data), and not
-trained using the back-propagation approach.
+trained using the backpropagation approach.
 Only the weights of the MLP
 are actually trained during the optimization loop.
 
 .. code:: ipython3
 
-    poisson_problem = SnapshotProblem()
-
     pod_nn = PODNN(pod_rank=20, layers=[10, 10, 10], func=torch.nn.Tanh)
-    pod_nn.fit_pod(u)
-
-    pinn_stokes = SupervisedSolver(
-        problem=poisson_problem,
-        model=pod_nn,
+    pod_nn.fit_pod(u_train)
+
+    pod_nn_stokes = SupervisedSolver(
+        problem=poisson_problem,
+        model=pod_nn,
         optimizer=torch.optim.Adam,
         optimizer_kwargs={'lr': 0.0001})
 
-Now that we set the ``Problem`` and the ``Model``, we have just to train
-the model and use it for predict the test snapshots.
+Now that we have set the ``Problem`` and the ``Model``, we have just to
+train the model and use it for predicting the test snapshots.
 
 .. code:: ipython3
 
     trainer = Trainer(
-        solver=pinn_stokes,
+        solver=pod_nn_stokes,
         max_epochs=1000,
         batch_size=100,
         log_every_n_steps=5,
@@ -196,15 +274,41 @@ the model and use it for predict the test snapshots.
 
 .. parsed-literal::
 
-    `Trainer.fit` stopped: `max_epochs=1000` reached.
+    GPU available: True (cuda), used: False
+    TPU available: False, using: 0 TPU cores
+    IPU available: False, using: 0 IPUs
+    HPU available: False, using: 0 HPUs
+    /u/a/aivagnes/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
+
+      | Name        | Type    | Params
+    ----------------------------------------
+    0 | _loss       | MSELoss | 0
+    1 | _neural_net | Network | 460
+    ----------------------------------------
+    460       Trainable params
+    0         Non-trainable params
+    460       Total params
+    0.002     Total estimated model params size (MB)
+    /u/a/aivagnes/anaconda3/lib/python3.8/site-packages/torch/cuda/__init__.py:152: UserWarning:
+        Found GPU0 Quadro K600 which is of cuda capability 3.0.
+        PyTorch no longer supports this GPU because it is too old.
+        The minimum cuda capability supported by this library is 3.7.
+
+      warnings.warn(old_gpu_warn % (d, name, major, minor, min_arch // 10, min_arch % 10))
+
 
 .. parsed-literal::
 
-    Epoch 999: 100%|██████████| 5/5 [00:00<00:00, 248.36it/s, v_num=20, mean_loss=0.902]
+    Training: | | 0/? [00:00`_
+    """
+    def __init__(
+        self,
+        neighbors=None,
+        smoothing=0.0,
+        kernel="thin_plate_spline",
+        epsilon=None,
+        degree=None,
+    ):
+        """
+        :param int neighbors: Number of neighbors to use for the
+            interpolation.
+            If ``None``, use all data points.
+        :param float smoothing: Smoothing parameter for the interpolation.
+            If 0.0, the interpolation is exact and no smoothing is applied.
+        :param str kernel: Radial basis function to use. Must be one of
+            ``linear``, ``thin_plate_spline``, ``cubic``, ``quintic``,
+            ``multiquadric``, ``inverse_multiquadric``, ``inverse_quadratic``,
+            or ``gaussian``.
+        :param float epsilon: Shape parameter that scales the input to
+            the RBF. This defaults to 1 for kernels in the ``scale_invariant``
+            dictionary, and must be specified for other kernels.
+        :param int degree: Degree of the added polynomial.
+            For some kernels, there exists a minimum degree of the polynomial
+            such that the RBF is well-posed. Those minimum degrees are specified
+            in the `min_degree_funcs` dictionary above. If `degree` is less than
+            the minimum degree, a warning is raised and the degree is set to the
+            minimum value.
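+
+        :Example:
+            A minimal usage sketch (an added illustration; it assumes 2D
+            ``(n, d)`` inputs, as required by :meth:`fit`, and the default
+            exact-interpolation settings):
+
+            >>> x = torch.linspace(0, 1, 10).unsqueeze(1)
+            >>> y = torch.sin(x)
+            >>> rbf = RBFBlock()
+            >>> rbf.fit(x, y)
+            >>> rbf(x).shape
+            torch.Size([10, 1])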
+        """
+
+        super().__init__()
+        check_consistency(neighbors, (int, type(None)))
+        check_consistency(smoothing, (int, float, torch.Tensor))
+        check_consistency(kernel, str)
+        check_consistency(epsilon, (float, type(None)))
+        check_consistency(degree, (int, type(None)))
+
+        self.neighbors = neighbors
+        self.smoothing = smoothing
+        self.kernel = kernel
+        self.epsilon = epsilon
+        self.degree = degree
+        self.powers = None
+        # initialize data points and values
+        self.y = None
+        self.d = None
+        # initialize attributes for the fitted model
+        self._shift = None
+        self._scale = None
+        self._coeffs = None
+
+    @property
+    def smoothing(self):
+        """
+        Smoothing parameter for the interpolation.
+
+        :rtype: float
+        """
+        return self._smoothing
+
+    @smoothing.setter
+    def smoothing(self, value):
+        self._smoothing = value
+
+    @property
+    def kernel(self):
+        """
+        Radial basis function to use.
+
+        :rtype: str
+        """
+        return self._kernel
+
+    @kernel.setter
+    def kernel(self, value):
+        # normalize before the membership check, so that the lookup and the
+        # stored value are consistent
+        value = value.lower()
+        if value not in radial_functions:
+            raise ValueError(f"Unknown kernel: {value}")
+        self._kernel = value
+
+    @property
+    def epsilon(self):
+        """
+        Shape parameter that scales the input to the RBF.
+
+        :rtype: float
+        """
+        return self._epsilon
+
+    @epsilon.setter
+    def epsilon(self, value):
+        if value is None:
+            if self.kernel in scale_invariant:
+                value = 1.0
+            else:
+                raise ValueError("Must specify `epsilon` for this kernel.")
+        else:
+            value = float(value)
+        self._epsilon = value
+
+    @property
+    def degree(self):
+        """
+        Degree of the added polynomial.
+
+        :rtype: int
+        """
+        return self._degree
+
+    @degree.setter
+    def degree(self, value):
+        min_degree = min_degree_funcs.get(self.kernel, -1)
+        if value is None:
+            value = max(min_degree, 0)
+        else:
+            value = int(value)
+            if value < -1:
+                raise ValueError("`degree` must be at least -1.")
+            if value < min_degree:
+                warnings.warn(
+                    "`degree` is too small for this kernel. Setting to "
+                    f"{min_degree}.", UserWarning,
+                )
+        self._degree = value
+
+    def _check_data(self, y, d):
+        if y.ndim != 2:
+            raise ValueError("y must be a 2-dimensional tensor.")
+
+        if d.shape[0] != y.shape[0]:
+            raise ValueError(
+                "The first dim of d must have the same length as "
+                "the first dim of y."
+            )
+
+        if isinstance(self.smoothing, (int, float)):
+            self.smoothing = torch.full((y.shape[0],), self.smoothing
+                    ).float().to(y.device)
+
+    def fit(self, y, d):
+        """
+        Fit the RBF interpolator to the data.
+
+        :param torch.Tensor y: (n, d) tensor of data points.
+        :param torch.Tensor d: (n, m) tensor of data values.
+        """
+        self._check_data(y, d)
+
+        self.y = y
+        self.d = d
+
+        if self.neighbors is None:
+            nobs = self.y.shape[0]
+        else:
+            raise NotImplementedError("neighbors currently not supported")
+
+        powers = RBFBlock.monomial_powers(self.y.shape[1], self.degree).to(
+                y.device)
+        if powers.shape[0] > nobs:
+            raise ValueError("The data is not compatible with the "
+                    "requested degree.")
+
+        if self.neighbors is None:
+            self._shift, self._scale, self._coeffs = RBFBlock.solve(self.y,
+                    self.d.reshape((self.y.shape[0], -1)),
+                    self.smoothing, self.kernel, self.epsilon, powers)
+
+        self.powers = powers
+
+    def forward(self, x):
+        """
+        Returns the interpolated data at the given points `x`.
+
+        :param torch.Tensor x: `(n, d)` tensor of points at which
+            to query the interpolator.
+
+        :rtype: `(n, m)` torch.Tensor of interpolated data.
+        """
+        if x.ndim != 2:
+            raise ValueError("`x` must be a 2-dimensional tensor.")
+
+        nx, ndim = x.shape
+        if ndim != self.y.shape[1]:
+            raise ValueError(
+                "Expected the second dim of `x` to have length "
+                f"{self.y.shape[1]}."
+            )
+
+        kernel_func = radial_functions[self.kernel]
+
+        yeps = self.y * self.epsilon
+        xeps = x * self.epsilon
+        xhat = (x - self._shift) / self._scale
+
+        kv = RBFBlock.kernel_vector(xeps, yeps, kernel_func)
+        p = RBFBlock.polynomial_matrix(xhat, self.powers)
+        vec = torch.cat([kv, p], dim=1)
+        out = torch.matmul(vec, self._coeffs)
+        out = out.reshape((nx,) + self.d.shape[1:])
+        return out
+
+    @staticmethod
+    def kernel_vector(x, y, kernel_func):
+        """
+        Evaluate radial functions with centers `y` for all points in `x`.
+
+        :param torch.Tensor x: `(n, d)` tensor of points.
+        :param torch.Tensor y: `(m, d)` tensor of centers.
+        :param callable kernel_func: Radial basis function to use.
+
+        :rtype: `(n, m)` torch.Tensor of radial function values.
+        """
+        return kernel_func(torch.cdist(x, y))
+
+    @staticmethod
+    def polynomial_matrix(x, powers):
+        """
+        Evaluate monomials at `x` with given `powers`.
+
+        :param torch.Tensor x: `(n, d)` tensor of points.
+        :param torch.Tensor powers: `(r, d)` tensor of powers for each monomial.
+
+        :rtype: `(n, r)` torch.Tensor of monomial values.
+        """
+        x_ = torch.repeat_interleave(x, repeats=powers.shape[0], dim=0)
+        powers_ = powers.repeat(x.shape[0], 1)
+        return torch.prod(x_**powers_, dim=1).view(x.shape[0], powers.shape[0])
+
+    @staticmethod
+    def kernel_matrix(x, kernel_func):
+        """
+        Returns radial function values for all pairs of points in `x`.
+
+        :param torch.Tensor x: `(n, d)` tensor of points.
+        :param callable kernel_func: Radial basis function to use.
+
+        :rtype: `(n, n)` torch.Tensor of radial function values.
+        """
+        return kernel_func(torch.cdist(x, x))
+
+    @staticmethod
+    def monomial_powers(ndim, degree):
+        """
+        Return the powers for each monomial in a polynomial.
+
+        :param int ndim: Number of variables in the polynomial.
+        :param int degree: Degree of the polynomial.
+
+        :rtype: `(nmonos, ndim)` torch.Tensor where each row contains the powers
+            for each variable in a monomial.
+
+        """
+        nmonos = math.comb(degree + ndim, ndim)
+        out = torch.zeros((nmonos, ndim), dtype=torch.int32)
+        count = 0
+        for deg in range(degree + 1):
+            for mono in combinations_with_replacement(range(ndim), deg):
+                for var in mono:
+                    out[count, var] += 1
+                count += 1
+        return out
+
+    @staticmethod
+    def build(y, d, smoothing, kernel, epsilon, powers):
+        """
+        Build the RBF linear system.
+
+        :param torch.Tensor y: (n, d) tensor of data points.
+        :param torch.Tensor d: (n, m) tensor of data values.
+        :param torch.Tensor smoothing: (n,) tensor of smoothing parameters.
+        :param str kernel: Radial basis function to use.
+        :param float epsilon: Shape parameter that scales the input to the RBF.
+        :param torch.Tensor powers: (r, d) tensor of powers for each monomial.
+
+        :rtype: (lhs, rhs, shift, scale) where `lhs` and `rhs` are the
+            left-hand side and right-hand side of the linear system, and
+            `shift` and `scale` are the shift and scale parameters.
+        """
+        p = d.shape[0]
+        s = d.shape[1]
+        r = powers.shape[0]
+        kernel_func = radial_functions[kernel]
+
+        mins = torch.min(y, dim=0).values
+        maxs = torch.max(y, dim=0).values
+        shift = (maxs + mins) / 2
+        scale = (maxs - mins) / 2
+
+        scale[scale == 0.0] = 1.0
+
+        yeps = y * epsilon
+        yhat = (y - shift) / scale
+
+        lhs = torch.empty((p + r, p + r), device=d.device).float()
+        lhs[:p, :p] = RBFBlock.kernel_matrix(yeps, kernel_func)
+        lhs[:p, p:] = RBFBlock.polynomial_matrix(yhat, powers)
+        lhs[p:, :p] = lhs[:p, p:].T
+        lhs[p:, p:] = 0.0
+        lhs[:p, :p] += torch.diag(smoothing)
+
+        rhs = torch.empty((r + p, s), device=d.device).float()
+        rhs[:p] = d
+        rhs[p:] = 0.0
+        return lhs, rhs, shift, scale
+
+    @staticmethod
+    def solve(y, d, smoothing, kernel, epsilon, powers):
+        """
+        Build then solve the RBF linear system.
+
+        :param torch.Tensor y: (n, d) tensor of data points.
+        :param torch.Tensor d: (n, m) tensor of data values.
+        :param torch.Tensor smoothing: (n,) tensor of smoothing parameters.
+
+        :param str kernel: Radial basis function to use.
+        :param float epsilon: Shape parameter that scales the input to the RBF.
+        :param torch.Tensor powers: (r, d) tensor of powers for each monomial.
+
+        :raises ValueError: If the linear system is singular.
+
+        :rtype: (shift, scale, coeffs) where `shift` and `scale` are the
+            shift and scale parameters, and `coeffs` are the coefficients
+            of the interpolator.
+        """
+
+        lhs, rhs, shift, scale = RBFBlock.build(y, d, smoothing, kernel,
+                epsilon, powers)
+        try:
+            coeffs = torch.linalg.solve(lhs, rhs)
+        except RuntimeError as e:
+            msg = "Singular matrix."
+            nmonos = powers.shape[0]
+            if nmonos > 0:
+                pmat = RBFBlock.polynomial_matrix((y - shift) / scale, powers)
+                rank = torch.linalg.matrix_rank(pmat)
+                if rank < nmonos:
+                    msg = (
+                        "Singular matrix. The matrix of monomials evaluated at "
+                        "the data point coordinates does not have full column "
+                        f"rank ({rank}/{nmonos})."
+                    )
+
+            raise ValueError(msg) from e
+
+        return shift, scale, coeffs
diff --git a/tests/test_layers/test_rbf.py b/tests/test_layers/test_rbf.py
new file mode 100644
index 0000000..43f19f3
--- /dev/null
+++ b/tests/test_layers/test_rbf.py
@@ -0,0 +1,85 @@
+import torch
+import pytest
+import math
+
+from pina.model.layers.rbf_layer import RBFBlock
+
+x = torch.linspace(-1, 1, 100)
+toy_params = torch.linspace(0, 1, 10).unsqueeze(1)
+toy_snapshots = torch.vstack([torch.exp(-x**2)*c for c in toy_params])
+toy_params_test = torch.linspace(0, 1, 3).unsqueeze(1)
+toy_snapshots_test = torch.vstack([torch.exp(-x**2)*c for c in toy_params_test])
+
+kernels = ["linear", "thin_plate_spline", "cubic", "quintic",
+        "multiquadric", "inverse_multiquadric", "inverse_quadratic", "gaussian"]
+
+noscale_invariant_kernels = ["multiquadric", "inverse_multiquadric",
+        "inverse_quadratic", "gaussian"]
+
+scale_invariant_kernels = ["linear", "thin_plate_spline", "cubic", "quintic"]
+
+def test_constructor_default():
+    rbf = RBFBlock()
+    assert rbf.kernel == "thin_plate_spline"
+    assert rbf.epsilon == 1
+    assert rbf.smoothing == 0.
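+
+# A possible additional property check (a sketch, under the same toy-data
+# assumptions as the tests above): with a nonzero smoothing parameter the fit
+# becomes a regularized approximation, so the prediction should no longer
+# match the training values exactly.
+def test_smoothing_not_exact():
+    rbf = RBFBlock(smoothing=1.0)
+    rbf.fit(toy_params, toy_snapshots)
+    c = rbf(toy_params)
+    assert c.shape == toy_snapshots.shape
+    assert not torch.allclose(c, toy_snapshots)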
+
+@pytest.mark.parametrize("kernel", kernels)
+@pytest.mark.parametrize("epsilon", [0.1, 1., 10.])
+def test_constructor_epsilon(kernel, epsilon):
+    if kernel in scale_invariant_kernels:
+        rbf = RBFBlock(kernel=kernel)
+        assert rbf.kernel == kernel
+        assert rbf.epsilon == 1
+    elif kernel in noscale_invariant_kernels:
+        with pytest.raises(ValueError):
+            rbf = RBFBlock(kernel=kernel)
+        rbf = RBFBlock(kernel=kernel, epsilon=epsilon)
+        assert rbf.kernel == kernel
+        assert rbf.epsilon == epsilon
+
+    assert rbf.smoothing == 0.
+
+@pytest.mark.parametrize("kernel", kernels)
+@pytest.mark.parametrize("epsilon", [0.1, 1., 10.])
+@pytest.mark.parametrize("degree", [2, 3, 4])
+@pytest.mark.parametrize("smoothing", [1e-5, 1e-3, 1e-1])
+def test_constructor_all(kernel, epsilon, degree, smoothing):
+    rbf = RBFBlock(kernel=kernel, epsilon=epsilon, degree=degree,
+            smoothing=smoothing)
+    assert rbf.kernel == kernel
+    assert rbf.epsilon == epsilon
+    assert rbf.degree == degree
+    assert rbf.smoothing == smoothing
+    assert rbf.y is None
+    assert rbf.d is None
+    assert rbf.powers is None
+    assert rbf._shift is None
+    assert rbf._scale is None
+    assert rbf._coeffs is None
+
+def test_fit():
+    rbf = RBFBlock()
+    rbf.fit(toy_params, toy_snapshots)
+    ndim = toy_params.shape[1]
+    torch.testing.assert_close(rbf.y, toy_params)
+    torch.testing.assert_close(rbf.d, toy_snapshots)
+    assert rbf.powers.shape == (math.comb(rbf.degree+ndim, ndim), ndim)
+    assert rbf._shift.shape == (ndim,)
+    assert rbf._scale.shape == (ndim,)
+    assert rbf._coeffs.shape == (rbf.powers.shape[0]+toy_snapshots.shape[0], toy_snapshots.shape[1])
+
+def test_forward():
+    rbf = RBFBlock()
+    rbf.fit(toy_params, toy_snapshots)
+    c = rbf(toy_params)
+    assert c.shape == toy_snapshots.shape
+    torch.testing.assert_close(c, toy_snapshots)
+
+def test_forward_unseen_parameters():
+    rbf = RBFBlock()
+    rbf.fit(toy_params, toy_snapshots)
+    c = rbf(toy_params_test)
+    assert c.shape == toy_snapshots_test.shape
+    torch.testing.assert_close(c, toy_snapshots_test)
+
diff --git a/tutorials/README.md b/tutorials/README.md
index aaccd57..5838a1f 100644
--- a/tutorials/README.md
+++ b/tutorials/README.md
@@ -32,4 +32,4 @@ Time dependent Kuramoto Sivashinsky equation using the Averaging Neural Operator
 | Description   | Tutorial  |
 |---------------|-----------|
 Unstructured convolutional autoencoder via continuous convolution |[[.ipynb](tutorial4/tutorial.ipynb), [.py](tutorial4/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial4/tutorial.html)]|
-POD-NN for reduced order modeling| [[.ipynb](tutorial8/tutorial.ipynb), [.py](tutorial8/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial8/tutorial.html)]|
+POD-RBF and POD-NN for reduced order modeling| [[.ipynb](tutorial8/tutorial.ipynb), [.py](tutorial8/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial8/tutorial.html)]|
diff --git a/tutorials/tutorial8/tutorial.ipynb b/tutorials/tutorial8/tutorial.ipynb
index 9115d12..fb368b5 100644
--- a/tutorials/tutorial8/tutorial.ipynb
+++ b/tutorials/tutorial8/tutorial.ipynb
@@ -5,7 +5,7 @@
    "id": "dbbb73cb-a632-4056-bbca-b483b2ad5f9c",
    "metadata": {},
    "source": [
-    "# Tutorial: Reduced order model (PODNN) for parametric problems"
+    "# Tutorial: Reduced order model (POD-RBF or POD-NN) for parametric problems"
    ]
   },
   {
   "cell_type": "markdown",
   "id": "84508f26-1ba6-4b59-926b-3e340d632a15",
   "metadata": {},
   "source": [
-    "The tutorial aims to show how to employ the **PINA** library in order to apply a reduced order modeling
technique [1]. Such methodologies have several similarities with machine learning approaches, since the main goal consists of predicting the solution of differential equations (typically parametric PDEs) in a real-time fashion.\n", + "The tutorial aims to show how to employ the **PINA** library in order to apply a reduced order modeling technique [1]. Such methodologies have several similarities with machine learning approaches, since the main goal consists in predicting the solution of differential equations (typically parametric PDEs) in a real-time fashion.\n", "\n", - "In particular we are going to use the Proper Orthogonal Decomposition with Neural Network (PODNN) [2], which basically performs a dimensional reduction using the POD approach, approximating the parametric solution manifold (at the reduced space) using a NN. In this example, we use a simple multilayer perceptron, but the plenty of different architectures can be plugged as well.\n", + "In particular we are going to use the Proper Orthogonal Decomposition with either Radial Basis Function Interpolation(POD-RBF) or Neural Network (POD-NN) [2]. Here we basically perform a dimensional reduction using the POD approach, and approximating the parametric solution manifold (at the reduced space) using an interpolation (RBF) or a regression technique (NN). In this example, we use a simple multilayer perceptron, but the plenty of different architectures can be plugged as well.\n", "\n", "#### References\n", "1. Rozza G., Stabile G., Ballarin F. (2022). Advanced Reduced Order Methods and Applications in Computational Fluid Dynamics, Society for Industrial and Applied Mathematics. \n", @@ -33,7 +33,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 1, "id": "00d1027d-13f2-4619-9ff7-a740568f13ff", "metadata": {}, "outputs": [ @@ -41,7 +41,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "We are using PINA version 0.1\n" + "We are using PINA version 0.1.1\n" ] } ], @@ -55,7 +55,7 @@ "from pina.geometry import CartesianDomain\n", "\n", "from pina.problem import ParametricProblem\n", - "from pina.model.layers import PODBlock\n", + "from pina.model.layers import PODBlock, RBFBlock\n", "from pina import Condition, LabelTensor, Trainer\n", "from pina.model import FeedForward\n", "from pina.solvers import SupervisedSolver\n", @@ -71,35 +71,25 @@ "We exploit the [Smithers](www.github.com/mathLab/Smithers) library to collect the parametric snapshots. In particular, we use the `NavierStokesDataset` class that contains a set of parametric solutions of the Navier-Stokes equations in a 2D L-shape domain. The parameter is the inflow velocity.\n", "The dataset is composed by 500 snapshots of the velocity (along $x$, $y$, and the magnitude) and pressure fields, and the corresponding parameter values.\n", "\n", - "To visually check the snapshots, let's plot also the data points and the reference solution: this is the expected output of the neural network." + "To visually check the snapshots, let's plot also the data points and the reference solution: this is the expected output of our model." ] }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 2, "id": "2c55d972-09a9-41de-9400-ba051c28cdcb", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 0: 0%| | 0/5 [48:45" + "
" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" } ], @@ -124,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 3, "id": "bd081bcd-192f-4370-a013-9b73050b5383", "metadata": {}, "outputs": [], @@ -153,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 4, "id": "55cef553-7495-401d-9d17-1acff8ec5953", "metadata": {}, "outputs": [], @@ -163,8 +153,26 @@ " parameter_domain = CartesianDomain({'mu': [0, 100]})\n", "\n", " conditions = {\n", - " 'io': Condition(input_points=p, output_points=u)\n", - " }" + " 'io': Condition(input_points=p_train, output_points=u_train)\n", + " }\n", + "\n", + "poisson_problem = SnapshotProblem()" + ] + }, + { + "cell_type": "markdown", + "id": "3b255526", + "metadata": {}, + "source": [ + "We can then build a `PODRBF` model (using a Radial Basis Function interpolation as approximation) and a `PODNN` approach (using an MLP architecture as approximation)." + ] + }, + { + "cell_type": "markdown", + "id": "352ac702", + "metadata": {}, + "source": [ + "## POD-RBF reduced order model" ] }, { @@ -172,12 +180,112 @@ "id": "6b264569-57b3-458d-bb69-8e94fe89017d", "metadata": {}, "source": [ - "Then, we define the model we want to use: an MLP architecture which takes in input the parameter and returns the *modal coefficients*, i.e.the interpolated coefficients of the POD expansion. Such coefficients are projected to the original space using the POD modes, which are computed and stored in the `PODBlock` object." + "Then, we define the model we want to use, with the POD (`PODBlock`) and the RBF (`RBFBlock`) objects." ] }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 5, + "id": "0bd2c30c", + "metadata": {}, + "outputs": [], + "source": [ + "class PODRBF(torch.nn.Module):\n", + " \"\"\"\n", + " Proper orthogonal decomposition with Radial Basis Function interpolation model.\n", + " \"\"\"\n", + "\n", + " def __init__(self, pod_rank, rbf_kernel):\n", + " \"\"\"\n", + " \n", + " \"\"\"\n", + " super().__init__()\n", + " \n", + " self.pod = PODBlock(pod_rank)\n", + " self.rbf = RBFBlock(kernel=rbf_kernel)\n", + " \n", + "\n", + " def forward(self, x):\n", + " \"\"\"\n", + " Defines the computation performed at every call.\n", + "\n", + " :param x: The tensor to apply the forward pass.\n", + " :type x: torch.Tensor\n", + " :return: the output computed by the model.\n", + " :rtype: torch.Tensor\n", + " \"\"\"\n", + " coefficents = self.rbf(x)\n", + " return self.pod.expand(coefficents)\n", + "\n", + " def fit(self, p, x):\n", + " \"\"\"\n", + " Call the :meth:`pina.model.layers.PODBlock.fit` method of the\n", + " :attr:`pina.model.layers.PODBlock` attribute to perform the POD,\n", + " and the :meth:`pina.model.layers.RBFBlock.fit` method of the\n", + " :attr:`pina.model.layers.RBFBlock` attribute to fit the interpolation.\n", + " \"\"\"\n", + " self.pod.fit(x)\n", + " self.rbf.fit(p, self.pod.reduce(x))" + ] + }, + { + "cell_type": "markdown", + "id": "4d2551ff", + "metadata": {}, + "source": [ + "We can then fit the model and ask it to predict the required field for unseen values of the parameters. Note that this model does not need a `Trainer` since it does not include any neural network or learnable parameters." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "af0a7f9b", + "metadata": {}, + "outputs": [], + "source": [ + "pod_rbf = PODRBF(pod_rank=20, rbf_kernel='thin_plate_spline')\n", + "pod_rbf.fit(p_train, u_train)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "41a27834", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Error summary for POD-RBF model:\n", + " Train: 1.287801e-03\n", + " Test: 1.217041e-03\n" + ] + } + ], + "source": [ + "u_test_rbf = pod_rbf(p_test)\n", + "u_train_rbf = pod_rbf(p_train)\n", + "\n", + "relative_error_train = torch.norm(u_train_rbf - u_train)/torch.norm(u_train)\n", + "relative_error_test = torch.norm(u_test_rbf - u_test)/torch.norm(u_test)\n", + "\n", + "print('Error summary for POD-RBF model:')\n", + "print(f' Train: {relative_error_train.item():e}')\n", + "print(f' Test: {relative_error_test.item():e}')" + ] + }, + { + "cell_type": "markdown", + "id": "a5bac005", + "metadata": {}, + "source": [ + "## POD-NN reduced order model" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "id": "c4170514-eb73-488e-8942-0129070e4e13", "metadata": {}, "outputs": [], @@ -232,17 +340,15 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 9, "id": "e998cad5-e3a7-4a3b-a1a5-400b6ff575a1", "metadata": {}, "outputs": [], "source": [ - "poisson_problem = SnapshotProblem()\n", - "\n", "pod_nn = PODNN(pod_rank=20, layers=[10, 10, 10], func=torch.nn.Tanh)\n", - "pod_nn.fit_pod(u)\n", + "pod_nn.fit_pod(u_train)\n", "\n", - "pinn_stokes = SupervisedSolver(\n", + "pod_nn_stokes = SupervisedSolver(\n", " problem=poisson_problem, \n", " model=pod_nn, \n", " optimizer=torch.optim.Adam,\n", @@ -259,7 +365,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 10, "id": "f1e94f42-cf80-4ca7-bb5e-ad47c1dd2784", "metadata": {}, "outputs": [ @@ -267,10 +373,11 @@ "name": "stderr", "output_type": "stream", "text": [ - "GPU available: False, used: False\n", + "GPU available: True (cuda), used: False\n", "TPU available: False, using: 0 TPU cores\n", "IPU available: False, using: 0 IPUs\n", "HPU available: False, using: 0 HPUs\n", + "/u/a/aivagnes/anaconda3/lib/python3.8/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.\n", "\n", " | Name | Type | Params\n", "----------------------------------------\n", @@ -280,15 +387,28 @@ "460 Trainable params\n", "0 Non-trainable params\n", "460 Total params\n", - "0.002 Total estimated model params size (MB)\n" + "0.002 Total estimated model params size (MB)\n", + "/u/a/aivagnes/anaconda3/lib/python3.8/site-packages/torch/cuda/__init__.py:152: UserWarning: \n", + " Found GPU0 Quadro K600 which is of cuda capability 3.0.\n", + " PyTorch no longer supports this GPU because it is too old.\n", + " The minimum cuda capability supported by this library is 3.7.\n", + " \n", + " warnings.warn(old_gpu_warn % (d, name, major, minor, min_arch // 10, min_arch % 10))\n" ] }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 999: 100%|██████████| 5/5 [00:00<00:00, 286.50it/s, v_num=20, mean_loss=0.902]" - ] + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a5ebdb14ddcb457da6d72432a4aa7a61", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Training: | | 0/? [00:00" + "
" ] }, - "metadata": {}, + "metadata": { + "needs_background": "light" + }, "output_type": "display_data" } ], "source": [ - "idx = torch.randint(0, len(u_test_pred), (4,))\n", - "u_idx = pinn_stokes(p_test[idx])\n", + "idx = torch.randint(0, len(u_test), (4,))\n", + "u_idx_rbf = pod_rbf(p_test[idx])\n", + "u_idx_nn = pod_nn_stokes(p_test[idx])\n", + "\n", "import numpy as np\n", "import matplotlib\n", - "fig, axs = plt.subplots(3, 4, figsize=(14, 9))\n", + "import matplotlib.pyplot as plt\n", "\n", - "relative_error = np.abs(u_test[idx] - u_idx.detach())\n", - "relative_error = np.where(u_test[idx] < 1e-7, 1e-7, relative_error/u_test[idx])\n", + "fig, axs = plt.subplots(5, 4, figsize=(14, 9))\n", + "\n", + "relative_error_rbf = np.abs(u_test[idx] - u_idx_rbf.detach())\n", + "relative_error_rbf = np.where(u_test[idx] < 1e-7, 1e-7, relative_error_rbf/u_test[idx])\n", + "\n", + "relative_error_nn = np.abs(u_test[idx] - u_idx_nn.detach())\n", + "relative_error_nn = np.where(u_test[idx] < 1e-7, 1e-7, relative_error_nn/u_test[idx])\n", " \n", - "for i, (idx_, u_, err_) in enumerate(zip(idx, u_idx, relative_error)):\n", - " cm = axs[0, i].tricontourf(dataset.triang, u_.detach())\n", + "for i, (idx_, rbf_, nn_, rbf_err_, nn_err_) in enumerate(\n", + " zip(idx, u_idx_rbf, u_idx_nn, relative_error_rbf, relative_error_nn)):\n", " axs[0, i].set_title(f'$\\mu$ = {p_test[idx_].item():.2f}')\n", - " plt.colorbar(cm)\n", - "\n", - " cm = axs[1, i].tricontourf(dataset.triang, u_test[idx_].flatten())\n", - " plt.colorbar(cm)\n", - "\n", - " cm = axs[2, i].tripcolor(dataset.triang, err_, norm=matplotlib.colors.LogNorm())\n", - " plt.colorbar(cm)\n", + " \n", + " cm = axs[0, i].tricontourf(dataset.triang, rbf_.detach()) # POD-RBF prediction\n", + " plt.colorbar(cm, ax=axs[0, i])\n", + " \n", + " cm = axs[1, i].tricontourf(dataset.triang, nn_.detach()) # POD-NN prediction\n", + " plt.colorbar(cm, ax=axs[1, i])\n", "\n", + " cm = axs[2, i].tricontourf(dataset.triang, u_test[idx_].flatten()) # Truth\n", + " plt.colorbar(cm, ax=axs[2, i])\n", "\n", + " cm = axs[3, i].tripcolor(dataset.triang, rbf_err_, norm=matplotlib.colors.LogNorm()) # Error for POD-RBF\n", + " plt.colorbar(cm, ax=axs[3, i])\n", + " \n", + " cm = axs[4, i].tripcolor(dataset.triang, nn_err_, norm=matplotlib.colors.LogNorm()) # Error for POD-NN\n", + " plt.colorbar(cm, ax=axs[4, i])\n", + " \n", "plt.show()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d3758c39", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3.9.18 ('gridcal')", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -418,7 +563,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" + "version": "3.8.8" }, "vscode": { "interpreter": { diff --git a/tutorials/tutorial8/tutorial.py b/tutorials/tutorial8/tutorial.py index 49b0b93..310c04b 100644 --- a/tutorials/tutorial8/tutorial.py +++ b/tutorials/tutorial8/tutorial.py @@ -1,20 +1,20 @@ #!/usr/bin/env python # coding: utf-8 -# # Tutorial: Reduced order model (PODNN) for parametric problems +# # Tutorial: Reduced order model (POD-RBF and POD-NN) for parametric problems -# The tutorial aims to show how to employ the **PINA** library in order to apply a reduced order modeling technique [1]. 
Such methodologies have several similarities with machine learning approaches, since the main goal consists of predicting the solution of differential equations (typically parametric PDEs) in a real-time fashion. -# -# In particular we are going to use the Proper Orthogonal Decomposition with Neural Network (PODNN) [2], which basically perform a dimensional reduction using the POD approach, approximating the parametric solution manifold (at the reduced space) using a NN. In this example, we use a simple multilayer perceptron, but the plenty of different archiutectures can be plugged as well. -# +# The tutorial aims to show how to employ the **PINA** library in order to apply a reduced order modeling technique [1]. Such methodologies have several similarities with machine learning approaches, since the main goal consists in predicting the solution of differential equations (typically parametric PDEs) in a real-time fashion. +# +# In particular we are going to use the Proper Orthogonal Decomposition with either Radial Basis Function Interpolation(POD-RBF) or Neural Network (POD-NN) [2]. Here we basically perform a dimensional reduction using the POD approach, and approximating the parametric solution manifold (at the reduced space) using an interpolation (RBF) or a regression technique (NN). In this example, we use a simple multilayer perceptron, but the plenty of different architectures can be plugged as well. +# # #### References -# 1. Rozza G., Stabile G., Ballarin F. (2022). Advanced Reduced Order Methods and Applications in Computational Fluid Dynamics, Society for Industrial and Applied Mathematics. +# 1. Rozza G., Stabile G., Ballarin F. (2022). Advanced Reduced Order Methods and Applications in Computational Fluid Dynamics, Society for Industrial and Applied Mathematics. # 2. Hesthaven, J. S., & Ubbiali, S. (2018). Non-intrusive reduced order modeling of nonlinear problems using neural networks. Journal of Computational Physics, 363, 55-78. # Let's start with the necessary imports. # It's important to note the minimum PINA version to run this tutorial is the `0.1`. -# In[29]: +# In[1]: get_ipython().run_line_magic('matplotlib', 'inline') @@ -26,7 +26,7 @@ import pina from pina.geometry import CartesianDomain from pina.problem import ParametricProblem -from pina.model.layers import PODBlock +from pina.model.layers import PODBlock, RBFBlock from pina import Condition, LabelTensor, Trainer from pina.model import FeedForward from pina.solvers import SupervisedSolver @@ -36,10 +36,10 @@ print(f'We are using PINA version {pina.__version__}') # We exploit the [Smithers](www.github.com/mathLab/Smithers) library to collect the parametric snapshots. In particular, we use the `NavierStokesDataset` class that contains a set of parametric solutions of the Navier-Stokes equations in a 2D L-shape domain. The parameter is the inflow velocity. # The dataset is composed by 500 snapshots of the velocity (along $x$, $y$, and the magnitude) and pressure fields, and the corresponding parameter values. -# -# To visually check the snapshots, let's plot also the data points and the reference solution: this is the expected output of the neural network. +# +# To visually check the snapshots, let's plot also the data points and the reference solution: this is the expected output of our model. 
-# In[30]:
+# In[2]:
 
 
 from smithers.dataset import NavierStokesDataset
@@ -51,10 +51,10 @@ for ax, p, u in zip(axs, dataset.params[:4], dataset.snapshots['mag(v)'][:4]):
     ax.set_title(f'$\mu$ = {p[0]:.2f}')
 
-# The *snapshots* - aka the numerical solutions computed for several parameters - and the corresponding parameters are the only data we need to train the model, in order to predict for any new test parameter the solution.
+# The *snapshots* - aka the numerical solutions computed for several parameters - and the corresponding parameters are the only data we need to train the model, in order to predict the solution for any new test parameter.
 # To properly validate the accuracy, we initially split the 500 snapshots into the training dataset (90% of the original data) and the testing one (the reamining 10%). It must be said that, to plug the snapshots into **PINA**, we have to cast them to `LabelTensor` objects.
 
-# In[31]:
+# In[3]:
 
 
 u = torch.tensor(dataset.snapshots['mag(v)']).float()
@@ -73,7 +73,7 @@ p_train, p_test = p[:n_train], p[n_train:]
 
 # It is now time to define the problem! We inherit from `ParametricProblem` (since the space invariant typically of this methodology), just defining a simple *input-output* condition.
 
-# In[32]:
+# In[4]:
 
 
 class SnapshotProblem(ParametricProblem):
@@ -81,13 +81,85 @@ class SnapshotProblem(ParametricProblem):
     parameter_domain = CartesianDomain({'mu': [0, 100]})
 
     conditions = {
-        'io': Condition(input_points=p, output_points=u)
+        'io': Condition(input_points=p_train, output_points=u_train)
     }
 
+poisson_problem = SnapshotProblem()
 
-# Then, we define the model we want to use: basically we have a MLP architecture that takes in input the parameter and return the *modal coefficients*, so the reduced dimension representation (the coordinates in the POD space). Such latent variable is the projected to the original space using the POD modes, which are computed and stored in the `PODBlock` object.
 
-# In[33]:
+# We can then build a `PODRBF` model (using Radial Basis Function interpolation for the approximation) and a `PODNN` approach (using an MLP architecture for the approximation).
+
+# ## POD-RBF reduced order model
+
+# Then, we define the model we want to use, with the POD (`PODBlock`) and the RBF (`RBFBlock`) objects.
+
+# In[5]:
+
+
+class PODRBF(torch.nn.Module):
+    """
+    Proper orthogonal decomposition with Radial Basis Function interpolation model.
+    """
+
+    def __init__(self, pod_rank, rbf_kernel):
+        """
+        :param int pod_rank: rank of the POD layer.
+        :param str rbf_kernel: radial basis function kernel.
+        """
+        super().__init__()
+
+        self.pod = PODBlock(pod_rank)
+        self.rbf = RBFBlock(kernel=rbf_kernel)
+
+
+    def forward(self, x):
+        """
+        Defines the computation performed at every call.
+
+        :param x: The tensor to apply the forward pass.
+        :type x: torch.Tensor
+        :return: the output computed by the model.
+        :rtype: torch.Tensor
+        """
+        coefficients = self.rbf(x)
+        return self.pod.expand(coefficients)
+
+    def fit(self, p, x):
+        """
+        Call the :meth:`pina.model.layers.PODBlock.fit` method of the
+        :attr:`pina.model.layers.PODBlock` attribute to perform the POD,
+        and the :meth:`pina.model.layers.RBFBlock.fit` method of the
+        :attr:`pina.model.layers.RBFBlock` attribute to fit the interpolation.
+        """
+        self.pod.fit(x)
+        self.rbf.fit(p, self.pod.reduce(x))
+
+
+# We can then fit the model and ask it to predict the required field for unseen values of the parameters. Note that this model does not need a `Trainer` since it does not include any neural network or learnable parameters.
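+#
+# The POD rank used below is fixed a priori. A possible refinement (a sketch,
+# not part of the original tutorial) is to pick the rank from the decay of the
+# singular values of the training snapshots, keeping enough modes to retain a
+# target fraction of the snapshot energy:
+#
+#     s = torch.linalg.svdvals(u_train)
+#     energy = torch.cumsum(s**2, dim=0) / torch.sum(s**2)
+#     pod_rank = int(torch.searchsorted(energy, 0.999)) + 1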
+
+# In[6]:
+
+
+pod_rbf = PODRBF(pod_rank=20, rbf_kernel='thin_plate_spline')
+pod_rbf.fit(p_train, u_train)
+
+
+# In[7]:
+
+
+u_test_rbf = pod_rbf(p_test)
+u_train_rbf = pod_rbf(p_train)
+
+relative_error_train = torch.norm(u_train_rbf - u_train)/torch.norm(u_train)
+relative_error_test = torch.norm(u_test_rbf - u_test)/torch.norm(u_test)
+
+print('Error summary for POD-RBF model:')
+print(f'  Train: {relative_error_train.item():e}')
+print(f'  Test:  {relative_error_test.item():e}')
+
+
+# ## POD-NN reduced order model
+
+# In[8]:
 
 
 class PODNN(torch.nn.Module):
     """
     Proper orthogonal decomposition with neural network model.
     """
@@ -97,10 +169,10 @@ class PODNN(torch.nn.Module):
 
     def __init__(self, pod_rank, layers, func):
         """
-        
+
        """
        super().__init__()
-        
+
        self.pod = PODBlock(pod_rank)
        self.nn = FeedForward(
            input_dimensions=1,
@@ -108,7 +180,7 @@ class PODNN(torch.nn.Module):
            layers=layers,
            func=func
        )
-        
+
 
    def forward(self, x):
        """
@@ -130,30 +202,28 @@ class PODNN(torch.nn.Module):
        self.pod.fit(x)
 
-# We highlight that the POD modes are directly computed by means of the singular value decomposition (computed over the input data), and not trained using the back-propagation approach. Only the weights of the MLP are actually trained during the optimization loop.
+# We highlight that the POD modes are directly computed by means of the singular value decomposition (computed over the input data), and not trained using the backpropagation approach. Only the weights of the MLP are actually trained during the optimization loop.
 
-# In[34]:
+# In[9]:
 
 
-poisson_problem = SnapshotProblem()
-
 pod_nn = PODNN(pod_rank=20, layers=[10, 10, 10], func=torch.nn.Tanh)
-pod_nn.fit_pod(u)
+pod_nn.fit_pod(u_train)
 
-pinn_stokes = SupervisedSolver(
-    problem=poisson_problem, 
-    model=pod_nn, 
+pod_nn_stokes = SupervisedSolver(
+    problem=poisson_problem, 
+    model=pod_nn, 
    optimizer=torch.optim.Adam,
    optimizer_kwargs={'lr': 0.0001})
 
-# Now that we set the `Problem` and the `Model`, we have just to train the model and use it for predict the test snapshots.
+# Now that we have set the `Problem` and the `Model`, we have just to train the model and use it for predicting the test snapshots.
 
-# In[35]:
+# In[10]:
 
 
 trainer = Trainer(
-    solver=pinn_stokes,
+    solver=pod_nn_stokes,
    max_epochs=1000,
    batch_size=100,
    log_every_n_steps=5,
    enable_model_summary=False,)
trainer.train()
 
-# Done! Now the computational expensive part is over, we can load in future the model to infer new parameters (simply loading the checkpoint file automatically created by `Lightning`) or test its performances. We measure the relative error for the training and test datasets, printing the mean one.
+# Done! Now that the computationally expensive part is over, we can later load the model to infer new parameters (simply by loading the checkpoint file automatically created by `Lightning`) or test its performance. We measure the relative error for the training and test datasets, printing the mean values.
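+#
+# A possible reloading pattern (a sketch: the checkpoint path below is
+# hypothetical, since it depends on the Lightning version and logger settings):
+#
+#     ckpt = 'lightning_logs/version_0/checkpoints/last.ckpt'  # hypothetical path
+#     pod_nn_stokes = SupervisedSolver.load_from_checkpoint(
+#         ckpt, problem=poisson_problem, model=pod_nn)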
-# In[36]:
+# In[11]:
 
 
-u_test_pred = pinn_stokes(p_test)
-u_train_pred = pinn_stokes(p_train)
+u_test_nn = pod_nn_stokes(p_test)
+u_train_nn = pod_nn_stokes(p_train)
 
-relative_error_train = torch.norm(u_train_pred - u_train)/torch.norm(u_train)
-relative_error_test = torch.norm(u_test_pred - u_test)/torch.norm(u_test)
+relative_error_train = torch.norm(u_train_nn - u_train)/torch.norm(u_train)
+relative_error_test = torch.norm(u_test_nn - u_test)/torch.norm(u_test)
 
-print('Error summary:')
+print('Error summary for POD-NN model:')
 print(f'  Train: {relative_error_train.item():e}')
 print(f'  Test:  {relative_error_test.item():e}')
 
-# We can of course also plot the solutions predicted by the `PODNN` model, comparing them to the original ones. We can note here some differences, especially for low velocities, but improvements can be accomplished thanks to longer training.
+# ## POD-RBF vs POD-NN
 
-# In[37]:
+# We can of course also plot the solutions predicted by the `PODRBF` and the `PODNN` models, comparing them to the original ones. We can note some differences in the `PODNN` predictions, especially for low velocities, but these can be reduced with longer training.
+
+# In[12]:
 
 
-idx = torch.randint(0, len(u_test_pred), (4,))
-u_idx = pinn_stokes(p_test[idx])
+idx = torch.randint(0, len(u_test), (4,))
+u_idx_rbf = pod_rbf(p_test[idx])
+u_idx_nn = pod_nn_stokes(p_test[idx])
+
 import numpy as np
 import matplotlib
-fig, axs = plt.subplots(3, 4, figsize=(14, 9))
+import matplotlib.pyplot as plt
 
-relative_error = np.abs(u_test[idx] - u_idx.detach())
-relative_error = np.where(u_test[idx] < 1e-7, 1e-7, relative_error/u_test[idx])
-
-for i, (idx_, u_, err_) in enumerate(zip(idx, u_idx, relative_error)):
-    cm = axs[0, i].tricontourf(dataset.triang, u_.detach())
+fig, axs = plt.subplots(5, 4, figsize=(14, 9))
+
+relative_error_rbf = np.abs(u_test[idx] - u_idx_rbf.detach())
+relative_error_rbf = np.where(u_test[idx] < 1e-7, 1e-7, relative_error_rbf/u_test[idx])
+
+relative_error_nn = np.abs(u_test[idx] - u_idx_nn.detach())
+relative_error_nn = np.where(u_test[idx] < 1e-7, 1e-7, relative_error_nn/u_test[idx])
 
+for i, (idx_, rbf_, nn_, rbf_err_, nn_err_) in enumerate(
+        zip(idx, u_idx_rbf, u_idx_nn, relative_error_rbf, relative_error_nn)):
     axs[0, i].set_title(f'$\mu$ = {p_test[idx_].item():.2f}')
-    plt.colorbar(cm)
 
-    cm = axs[1, i].tricontourf(dataset.triang, u_test[idx_].flatten())
-    plt.colorbar(cm)
+    cm = axs[0, i].tricontourf(dataset.triang, rbf_.detach()) # POD-RBF prediction
+    plt.colorbar(cm, ax=axs[0, i])
 
-    cm = axs[2, i].tripcolor(dataset.triang, err_, norm=matplotlib.colors.LogNorm())
-    plt.colorbar(cm)
+    cm = axs[1, i].tricontourf(dataset.triang, nn_.detach()) # POD-NN prediction
+    plt.colorbar(cm, ax=axs[1, i])
 
+    cm = axs[2, i].tricontourf(dataset.triang, u_test[idx_].flatten()) # Truth
+    plt.colorbar(cm, ax=axs[2, i])
+
+    cm = axs[3, i].tripcolor(dataset.triang, rbf_err_, norm=matplotlib.colors.LogNorm()) # Error for POD-RBF
+    plt.colorbar(cm, ax=axs[3, i])
+
+    cm = axs[4, i].tripcolor(dataset.triang, nn_err_, norm=matplotlib.colors.LogNorm()) # Error for POD-NN
+    plt.colorbar(cm, ax=axs[4, i])
 
 plt.show()
+
+
+# In[ ]:
+
+
+
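+# A final optional check (a sketch, not part of the original tutorial): both
+# reduced order models are callable on a batch of parameters, so their online
+# inference cost can be compared with a rough timing.
+
+import time
+
+for name, model in [('POD-RBF', pod_rbf), ('POD-NN', pod_nn_stokes)]:
+    start = time.time()
+    model(p_test)
+    print(f'{name}: {time.time() - start:.4f} s for {len(p_test)} parameters')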