PBC Layer (#252)
* update docs/tests
* tutorial and device fix

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.lan>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.Home>
@@ -68,6 +68,7 @@ Layers
 Spectral convolution <layers/spectral.rst>
 Fourier layers <layers/fourier.rst>
 Continuous convolution <layers/convolution.rst>
+Coordinates embeddings <layers/embedding.rst>
 
 Equations and Operators
@@ -21,7 +21,7 @@ Physics Informed Neural Networks
 Two dimensional Poisson problem using Extra Features Learning<tutorials/tutorial2/tutorial.rst>
 Two dimensional Wave problem with hard constraint<tutorials/tutorial3/tutorial.rst>
 Resolution of a 2D Poisson inverse problem<tutorials/tutorial7/tutorial.rst>
+Periodic Boundary Conditions for Helmholtz Equation<tutorials/tutorial9/tutorial.rst>
 
 Neural Operator Learning
 ------------------------
docs/source/_rst/layers/embedding.rst (new file, 8 lines)
@@ -0,0 +1,8 @@
Coordinates embeddings
======================

.. currentmodule:: pina.model.layers.embedding

.. autoclass:: PeriodicBoundaryEmbedding
   :members:
   :show-inheritance:
docs/source/_rst/layers/pod.rst (new file, 15 lines)
@@ -0,0 +1,15 @@
Spectral Convolution
====================

.. currentmodule:: pina.model.layers.spectral

.. autoclass:: SpectralConvBlock1D
   :members:
   :show-inheritance:

.. autoclass:: SpectralConvBlock2D
   :members:
   :show-inheritance:

.. autoclass:: SpectralConvBlock3D
   :members:
   :show-inheritance:
docs/source/_rst/tutorials/tutorial9/tutorial.rst (new file, 226 lines)
@@ -0,0 +1,226 @@
Tutorial: One dimensional Helmholtz equation using Periodic Boundary Conditions
================================================================================

This tutorial presents how to solve with Physics-Informed Neural
Networks (PINNs) a one dimensional Helmholtz equation with periodic
boundary conditions (PBC). We will train with standard PINN training,
augmenting the input with a periodic expansion as presented in `An
expert's guide to training physics-informed neural
networks <https://arxiv.org/abs/2308.08468>`__.

First of all, some useful imports.
.. code:: ipython3

    import torch
    import matplotlib.pyplot as plt

    from pina import Condition, Plotter
    from pina.problem import SpatialProblem
    from pina.operators import laplacian
    from pina.model import FeedForward
    from pina.model.layers import PeriodicBoundaryEmbedding  # The PBC module
    from pina.solvers import PINN
    from pina.trainer import Trainer
    from pina.geometry import CartesianDomain
    from pina.equation import Equation
The problem definition
----------------------

The one-dimensional Helmholtz problem is mathematically written as:

.. math::

   \begin{cases}
   \frac{d^2}{dx^2}u(x) - \lambda u(x) - f(x) &= 0 \quad x\in(0,2)\\
   u^{(m)}(x=0) - u^{(m)}(x=2) &= 0 \quad m\in[0, 1, \cdots]\\
   \end{cases}
In this case we are asking the solution to be :math:`C^{\infty}`
periodic with period :math:`2` on the infinite domain
:math:`x\in(-\infty, \infty)`. Notice that a classical PINN would need
infinitely many conditions to evaluate the PBC loss function, one for
each derivative, which is of course infeasible. A possible solution,
diverging from the original PINN formulation, is to use *coordinates
augmentation*. In coordinates augmentation one seeks a coordinate
transformation :math:`v`, :math:`x\rightarrow v(x)`, such that the
periodicity condition
:math:`u^{(m)}(x=0) - u^{(m)}(x=2) = 0 \quad m\in[0, 1, \cdots]`
is satisfied.

For demonstration purposes the problem specifics are
:math:`\lambda=-10\pi^2` and
:math:`f(x)=-6\pi^2\sin(3\pi x)\cos(\pi x)`, which give a solution that
can be computed analytically: :math:`u(x) = \sin(\pi x)\cos(3\pi x)`.
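As a quick sanity check (a hedged sketch, not part of the original
tutorial), we can verify with ``torch.autograd`` that this analytic
solution satisfies :math:`u'' - \lambda u - f = 0` up to floating point
error:

.. code:: ipython3

    # residual of the analytic solution, should be ~0
    xs = torch.linspace(0, 2, 50).reshape(-1, 1).requires_grad_(True)
    u = torch.sin(torch.pi * xs) * torch.cos(3 * torch.pi * xs)
    du = torch.autograd.grad(u, xs, torch.ones_like(u), create_graph=True)[0]
    d2u = torch.autograd.grad(du, xs, torch.ones_like(du))[0]
    f = -6 * torch.pi**2 * torch.sin(3 * torch.pi * xs) * torch.cos(torch.pi * xs)
    print((d2u - (-10 * torch.pi**2) * u - f).abs().max())  # ~0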
.. code:: ipython3

    class Helmotz(SpatialProblem):
        output_variables = ['u']
        spatial_domain = CartesianDomain({'x': [0, 2]})

        def helmotz_equation(input_, output_):
            x = input_.extract('x')
            u_xx = laplacian(output_, input_, components=['u'], d=['x'])
            f = -6. * torch.pi**2 * torch.sin(3 * torch.pi * x) * torch.cos(torch.pi * x)
            lambda_ = -10. * torch.pi**2
            return u_xx - lambda_ * output_ - f

        # here we write the problem conditions
        conditions = {
            'D': Condition(location=spatial_domain,
                           equation=Equation(helmotz_equation)),
        }

        def helmotz_sol(self, pts):
            return torch.sin(torch.pi * pts) * torch.cos(3. * torch.pi * pts)

        truth_solution = helmotz_sol

    problem = Helmotz()

    # let's discretise the domain
    problem.discretise_domain(200, 'grid', locations=['D'])
As usual, the Helmholtz problem is written in **PINA** code as a class. The
equations are written as ``conditions`` that should be satisfied in the
corresponding domains. The ``truth_solution`` is the exact solution,
which will be compared with the predicted one. We used a uniform grid
for sampling the collocation points.

Solving the problem with a Periodic Network
-------------------------------------------

Any :math:`\mathcal{C}^{\infty}` periodic function
:math:`u : \mathbb{R} \rightarrow \mathbb{R}` with period
:math:`L\in\mathbb{N}` can be constructed by composition of an arbitrary
smooth function :math:`f : \mathbb{R}^n \rightarrow \mathbb{R}` and a
given smooth periodic function
:math:`v : \mathbb{R} \rightarrow \mathbb{R}^n` with period :math:`L`,
that is :math:`u(x) = f(v(x))`. The formulation generalizes to
arbitrary dimension, see `A method for representing periodic functions
and enforcing exactly periodic boundary conditions with deep neural
networks <https://arxiv.org/pdf/2007.07442>`__.

In our case, we choose
:math:`v(x) = \left[1, \cos\left(\frac{2\pi}{L} x\right), \sin\left(\frac{2\pi}{L} x\right)\right]`,
i.e. the coordinates augmentation, and
:math:`f(\cdot) = NN_{\theta}(\cdot)`, i.e. a neural network. The
network obtained by composing :math:`f` with :math:`v`
gives the PINN approximate solution, that is
:math:`u(x) \approx u_{\theta}(x)=NN_{\theta}(v(x))`.

In **PINA** this translates into using the ``PeriodicBoundaryEmbedding`` layer for
:math:`v`, and any ``pina.model`` for :math:`NN_{\theta}`. Let's see it
in action!
.. code:: ipython3

    # we encapsulate all modules in a torch.nn.Sequential container
    model = torch.nn.Sequential(
        PeriodicBoundaryEmbedding(input_dimension=1, periods=2),
        FeedForward(input_dimensions=3,  # output of PeriodicBoundaryEmbedding = 3 * input_dimension
                    output_dimensions=1,
                    layers=[10, 10]))
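Before moving on, a quick hedged check (not in the original tutorial) of
what the embedding guarantees: inputs one period apart are mapped to
identical features, so the composed network is exactly periodic by
construction.

.. code:: ipython3

    emb = PeriodicBoundaryEmbedding(input_dimension=1, periods=2)
    x0 = torch.tensor([[0.3]])
    print(torch.allclose(emb(x0), emb(x0 + 2.)))  # True, up to floating point error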
As simple as that! Notice that in higher dimensions you can specify a
different period for each dimension using a dictionary,
e.g. ``periods={'x':2, 'y':3, ...}`` would indicate a periodicity of
:math:`2` in :math:`x`, :math:`3` in :math:`y`, and so on, as in the
sketch below.
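For instance, a hedged two-dimensional sketch (the layer would then
expect a ``LabelTensor`` carrying the labels ``'x'`` and ``'y'`` as
input):

.. code:: ipython3

    emb_2d = PeriodicBoundaryEmbedding(input_dimension=2,
                                       periods={'x': 2, 'y': 3})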
We will now solve the problem as usual with the ``PINN`` and
``Trainer`` classes.
.. code:: ipython3

    pinn = PINN(problem=problem, model=model)
    # we train on CPU and disable the model summary at the start of training (optional)
    trainer = Trainer(pinn, max_epochs=5000, accelerator='cpu',
                      enable_model_summary=False)
    trainer.train()
.. parsed-literal::

    GPU available: True (mps), used: False
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    HPU available: False, using: 0 HPUs

.. parsed-literal::

    `Trainer.fit` stopped: `max_epochs=5000` reached.

.. parsed-literal::

    Epoch 4999: 100%|██████████| 1/1 [00:00<00:00, 155.47it/s, v_num=20, D_loss=0.0123, mean_loss=0.0123]
We are going to plot the solution now!

.. code:: ipython3

    pl = Plotter()
    pl.plot(pinn)

.. image:: tutorial_files/tutorial_11_0.png
Great, they overlap perfectly! This seems a good result, considering
the simple neural network used to solve this (complex) problem. We will
now test the neural network on the extended domain :math:`[0, 4]`
without retraining. In principle the periodicity should be preserved,
since the :math:`v` function ensures periodicity on
:math:`(-\infty, \infty)`.
.. code:: ipython3

    # plotting solution
    with torch.no_grad():
        # Notice here we sample on [0, 4], beyond the training domain [0, 2]!
        new_domain = CartesianDomain({'x': [0, 4]})
        x = new_domain.sample(1000, mode='grid')
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        # Plot 1: true solution
        axes[0].plot(x, problem.truth_solution(x), label=r'$u(x)$', color='blue')
        axes[0].set_title(r'True solution $u(x)$')
        axes[0].legend(loc="upper right")
        # Plot 2: PINN solution
        axes[1].plot(x, pinn(x), label=r'$u_{\theta}(x)$', color='green')
        axes[1].set_title(r'PINN solution $u_{\theta}(x)$')
        axes[1].legend(loc="upper right")
        # Plot 3: absolute difference
        diff = torch.abs(problem.truth_solution(x) - pinn(x))
        axes[2].plot(x, diff, label=r'$|u(x) - u_{\theta}(x)|$', color='red')
        axes[2].set_title(r'Absolute difference $|u(x) - u_{\theta}(x)|$')
        axes[2].legend(loc="upper right")
        # Adjust layout and show the plots
        plt.tight_layout()
        plt.show()

.. image:: tutorial_files/tutorial_13_0.png
It is pretty clear that the network is periodic, with the error also
following a periodic pattern. Obviously a longer training and a more
expressive neural network could improve the results!

What's next?
------------

Nice, you have completed the one dimensional Helmholtz tutorial of
**PINA**! There are multiple directions you can go now:

1. Train the network for longer or with different layer sizes and
   assess the final accuracy

2. Apply the ``PeriodicBoundaryEmbedding`` layer to a time-dependent
   problem (see the reference in the documentation)

3. Exploit extra feature training?

4. Many more...
Binary file not shown. After: 59 KiB.
Binary file not shown. After: 90 KiB.
@@ -9,6 +9,7 @@ __all__ = [
     "FourierBlock2D",
     "FourierBlock3D",
     "PODBlock",
+    "PeriodicBoundaryEmbedding"
 ]

 from .convolution_2d import ContinuousConvBlock
@@ -20,3 +21,4 @@ from .spectral import (
 )
 from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
 from .pod import PODBlock
+from .embedding import PeriodicBoundaryEmbedding
pina/model/layers/embedding.py (new file, 142 lines)
@@ -0,0 +1,142 @@
""" Periodic Boundary Embedding module. """

import torch
from pina.utils import check_consistency


class PeriodicBoundaryEmbedding(torch.nn.Module):
    r"""
    Imposing hard constraint periodic boundary conditions by embedding the
    input.

    A function :math:`u:\mathbb{R}^{\rm{in}}
    \rightarrow\mathbb{R}^{\rm{out}}`, periodic in the spatial
    coordinates :math:`\mathbf{x}` with periods :math:`\mathbf{L}`, is such that:

    .. math::
        u(\mathbf{x}) = u(\mathbf{x} + n \mathbf{L})\;\;
        \forall n\in\mathbb{N}.

    The :meth:`PeriodicBoundaryEmbedding` augments the input such that the
    periodic conditions are guaranteed. The input is augmented by the
    following formula:

    .. math::
        \mathbf{x} \rightarrow \tilde{\mathbf{x}} = \left[1,
        \cos\left(\frac{2\pi}{L_1} x_1 \right),
        \sin\left(\frac{2\pi}{L_1}x_1\right), \cdots,
        \cos\left(\frac{2\pi}{L_{\rm{in}}}x_{\rm{in}}\right),
        \sin\left(\frac{2\pi}{L_{\rm{in}}}x_{\rm{in}}\right)\right],

    where :math:`\text{dim}(\tilde{\mathbf{x}}) = 3\text{dim}(\mathbf{x})`.
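    Example (a hedged sketch, not part of the original docstring):

    >>> layer = PeriodicBoundaryEmbedding(input_dimension=1, periods=2)
    >>> x = torch.rand(10, 1)
    >>> layer(x).shape
    torch.Size([10, 3])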
    .. seealso::
        **Original reference**:

        1. Dong, Suchuan, and Naxian Ni (2021). *A method for representing
           periodic functions and enforcing exactly periodic boundary
           conditions with deep neural networks*. Journal of Computational
           Physics 435, 110242.
           DOI: `10.1016/j.jcp.2021.110242
           <https://doi.org/10.1016/j.jcp.2021.110242>`_
        2. Wang, S., Sankaran, S., Wang, H., & Perdikaris, P. (2023). *An
           expert's guide to training physics-informed neural networks*.
           DOI: `arXiv preprint arXiv:2308.08468
           <https://arxiv.org/abs/2308.08468>`_

    .. warning::
        The embedding is a truncated Fourier expansion, and it only ensures
        periodicity of the function, not of its derivatives. Approximate
        periodicity in the derivatives of :math:`u` can still be obtained,
        and extensive tests (also in the reference papers) have shown that
        this implementation correctly reproduces the PBC on the derivatives
        up to order :math:`\sim 2, 3`, while periodicity is not guaranteed
        for orders :math:`>3`. The PINA code is tested only for function
        PBC, not for its derivatives.
    """
    def __init__(self, input_dimension, periods, output_dimension=None):
        """
        :param int input_dimension: The dimension of the input tensor, i.e.
            the number of input features, ``tensor.shape[-1]``.
        :param float | int | dict periods: The periodicity in each dimension
            for the input data. If a ``float`` or ``int`` is passed, the
            period is assumed constant over all the dimensions of the data.
            If a ``dict`` is passed, the ``dict.values`` represent the
            periods, while the ``dict.keys`` represent the dimensions where
            the periodicity is applied. The ``dict.keys`` can either be
            ``int`` if working with ``torch.Tensor``, or ``str`` if working
            with ``LabelTensor``.
        :param int output_dimension: The dimension of the output after the
            Fourier embedding. If not ``None``, a ``torch.nn.Linear`` layer
            is applied to the Fourier embedding output to match the desired
            dimensionality, default ``None``.
        """
        super().__init__()

        # check input consistency
        check_consistency(periods, (float, int, dict))
        check_consistency(input_dimension, int)
        if output_dimension is not None:
            check_consistency(output_dimension, int)
            self._layer = torch.nn.Linear(input_dimension * 3, output_dimension)
        else:
            self._layer = torch.nn.Identity()

        # checks on the periods
        if isinstance(periods, dict):
            if not all(isinstance(dim, (str, int)) and
                       isinstance(period, (float, int))
                       for dim, period in periods.items()):
                raise TypeError('In dictionary periods, keys must be integers'
                                ' or strings, and values must be float or int.')
            self._period = periods
        else:
            self._period = {k: periods for k in range(input_dimension)}
    def forward(self, x):
        """
        Forward pass to compute the periodic boundary conditions embedding.

        :param torch.Tensor x: Input tensor.
        :return: Fourier embeddings of the input.
        :rtype: torch.Tensor
        """
        # angular frequencies, one per periodic dimension
        omega = torch.stack([torch.pi * 2. / torch.tensor([val],
                                                          device=x.device)
                             for val in self._period.values()],
                            dim=-1)
        # extract the periodic variables and build [1, cos(omega x), sin(omega x)]
        x = self._get_vars(x, list(self._period.keys()))
        return self._layer(torch.cat([torch.ones_like(x),
                                      torch.cos(omega * x),
                                      torch.sin(omega * x)], dim=-1))

    def _get_vars(self, x, indeces):
        """
        Get variables from the input tensor, ordered by specific indices.

        :param torch.Tensor x: The input tensor to extract from.
        :param list[int] | list[str] indeces: List of indices to extract.
        :return: The extracted tensor given the indices.
        :rtype: torch.Tensor
        """
        if isinstance(indeces[0], str):
            try:
                return x.extract(indeces)
            except AttributeError:
                raise RuntimeError(
                    'Not possible to extract input variables from tensor.'
                    ' Ensure that the passed tensor is a LabelTensor or'
                    ' pass a list of integers to extract variables. For'
                    ' more information refer to the warning in the'
                    ' documentation.')
        elif isinstance(indeces[0], int):
            return x[..., indeces]
        else:
            raise RuntimeError(
                'Not able to extract the right indices from the tensor.'
                ' For more information refer to the warning in the'
                ' documentation.')

    @property
    def period(self):
        """
        The period of the periodic function to approximate.
        """
        return self._period
tests/test_layers/test_embedding.py (new file, 99 lines)
@@ -0,0 +1,99 @@
import torch
import pytest

from pina.model.layers import PeriodicBoundaryEmbedding
from pina import LabelTensor


def check_same_columns(tensor):
    # Get the first row of the output
    first_column = tensor[0]
    # Compare every row with the first one
    all_same = torch.allclose(tensor, first_column)
    return all_same


def grad(u, x):
    """
    Compute the first derivative of u with respect to x.
    """
    return torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u),
                               create_graph=True, allow_unused=True,
                               retain_graph=True)[0]


def test_constructor():
    PeriodicBoundaryEmbedding(input_dimension=1, periods=2)
    PeriodicBoundaryEmbedding(input_dimension=1, periods={'x': 3, 'y': 4})
    PeriodicBoundaryEmbedding(input_dimension=1, periods={0: 3, 1: 4})
    PeriodicBoundaryEmbedding(input_dimension=1, periods=2, output_dimension=10)
    with pytest.raises(TypeError):
        PeriodicBoundaryEmbedding()
    # each failing call gets its own context manager, otherwise the
    # statements after the first raise would never run
    with pytest.raises(ValueError):
        PeriodicBoundaryEmbedding(input_dimension=1., periods=1)
    with pytest.raises(ValueError):
        PeriodicBoundaryEmbedding(input_dimension=1, periods=1,
                                  output_dimension=1.)
    with pytest.raises(TypeError):
        PeriodicBoundaryEmbedding(input_dimension=1, periods={'x': 'x'})
    with pytest.raises(TypeError):
        PeriodicBoundaryEmbedding(input_dimension=1, periods={0: 'x'})


@pytest.mark.parametrize("period", [1, 4, 10])
@pytest.mark.parametrize("input_dimension", [1, 2, 3])
def test_forward_same_period(input_dimension, period):
    func = torch.nn.Sequential(
        PeriodicBoundaryEmbedding(input_dimension=input_dimension,
                                  output_dimension=60, periods=period),
        torch.nn.Tanh(),
        torch.nn.Linear(60, 60),
        torch.nn.Tanh(),
        torch.nn.Linear(60, 1)
    )
    # coordinates placed exactly one period apart
    x = period * torch.tensor([[0.], [1.]])
    if input_dimension == 2:
        x = torch.cartesian_prod(x.flatten(), x.flatten())
    elif input_dimension == 3:
        x = torch.cartesian_prod(x.flatten(), x.flatten(), x.flatten())
    x.requires_grad = True
    # output
    f = func(x)
    assert check_same_columns(f)


def test_forward_same_period_labels():
    func = torch.nn.Sequential(
        PeriodicBoundaryEmbedding(input_dimension=2,
                                  output_dimension=60, periods={'x': 1, 'y': 2}),
        torch.nn.Tanh(),
        torch.nn.Linear(60, 60),
        torch.nn.Tanh(),
        torch.nn.Linear(60, 1)
    )
    # coordinates
    tensor = torch.tensor([[0., 0.], [0., 2.], [1., 0.], [1., 2.]])
    with pytest.raises(RuntimeError):
        func(tensor)
    tensor = tensor.as_subclass(LabelTensor)
    tensor.labels = ['x', 'y']
    tensor.requires_grad = True
    # output
    f = func(tensor)
    assert check_same_columns(f)


def test_forward_same_period_index():
    func = torch.nn.Sequential(
        PeriodicBoundaryEmbedding(input_dimension=2,
                                  output_dimension=60, periods={0: 1, 1: 2}),
        torch.nn.Tanh(),
        torch.nn.Linear(60, 60),
        torch.nn.Tanh(),
        torch.nn.Linear(60, 1)
    )
    # coordinates
    tensor = torch.tensor([[0., 0.], [0., 2.], [1., 0.], [1., 2.]])
    tensor.requires_grad = True
    # output
    f = func(tensor)
    assert check_same_columns(f)
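
# A hedged sketch (not part of the original tests): use the `grad` helper
# above to check that the first derivative of the composed network is also
# periodic, as discussed in the PeriodicBoundaryEmbedding docstring warning.
def test_first_derivative_periodicity_sketch():
    torch.manual_seed(0)
    func = torch.nn.Sequential(
        PeriodicBoundaryEmbedding(input_dimension=1, periods=1),
        torch.nn.Linear(3, 1))
    x = torch.tensor([[0.], [1.]], requires_grad=True)  # one period apart
    du = grad(func(x), x)
    assert torch.allclose(du[0], du[1], atol=1e-5)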
tutorials/README.md (vendored; 1 line changed)
@@ -16,6 +16,7 @@ Building custom geometries with PINA `Location` class|[[.ipynb](tutorial6/tutori
 Two dimensional Poisson problem using Extra Features Learning |[[.ipynb](tutorial2/tutorial.ipynb), [.py](tutorial2/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial2/tutorial.html)]|
 Two dimensional Wave problem with hard constraint |[[.ipynb](tutorial3/tutorial.ipynb), [.py](tutorial3/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial3/tutorial.html)]|
 Resolution of a 2D Poisson inverse problem |[[.ipynb](tutorial7/tutorial.ipynb), [.py](tutorial7/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial7/tutorial.html)]|
+Periodic Boundary Conditions for Helmholtz Equation |[[.ipynb](tutorial9/tutorial.ipynb), [.py](tutorial9/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial9/tutorial.html)]|
 
 ## Neural Operator Learning
 | Description | Tutorial |
tutorials/tutorial9/tutorial.ipynb (vendored, new file, 298 lines). File diff suppressed because one or more lines are too long.

tutorials/tutorial9/tutorial.py (vendored, new file, 191 lines)
@@ -0,0 +1,191 @@
#!/usr/bin/env python
# coding: utf-8

# # Tutorial: One dimensional Helmholtz equation using Periodic Boundary Conditions
# This tutorial presents how to solve with Physics-Informed Neural Networks (PINNs)
# a one dimensional Helmholtz equation with periodic boundary conditions (PBC).
# We will train with standard PINN training, augmenting the input with a
# periodic expansion as presented in [*An expert's guide to training
# physics-informed neural networks*](https://arxiv.org/abs/2308.08468).
#
# First of all, some useful imports.
# In[1]:


import torch
import matplotlib.pyplot as plt

from pina import Condition, Plotter
from pina.problem import SpatialProblem
from pina.operators import laplacian
from pina.model import FeedForward
from pina.model.layers import PeriodicBoundaryEmbedding  # The PBC module
from pina.solvers import PINN
from pina.trainer import Trainer
from pina.geometry import CartesianDomain
from pina.equation import Equation
# ## The problem definition
#
# The one-dimensional Helmholtz problem is mathematically written as:
# $$
# \begin{cases}
# \frac{d^2}{dx^2}u(x) - \lambda u(x) - f(x) &= 0 \quad x\in(0,2)\\
# u^{(m)}(x=0) - u^{(m)}(x=2) &= 0 \quad m\in[0, 1, \cdots]\\
# \end{cases}
# $$
# In this case we are asking the solution to be $C^{\infty}$ periodic with
# period $2$ on the infinite domain $x\in(-\infty, \infty)$. Notice that a
# classical PINN would need infinitely many conditions to evaluate the PBC
# loss function, one for each derivative, which is of course infeasible...
# A possible solution, diverging from the original PINN formulation,
# is to use *coordinates augmentation*. In coordinates augmentation one seeks
# a coordinate transformation $v$, $x\rightarrow v(x)$, such that
# the periodicity condition $ u^{(m)}(x=0) - u^{(m)}(x=2) = 0 \quad m\in[0, 1, \cdots] $ is
# satisfied.
#
# For demonstration purposes the problem specifics are $\lambda=-10\pi^2$
# and $f(x)=-6\pi^2\sin(3\pi x)\cos(\pi x)$, which give a solution that can be
# computed analytically: $u(x) = \sin(\pi x)\cos(3\pi x)$.
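# A hedged sanity check (not part of the original tutorial): verify with
# torch.autograd that the analytic solution satisfies u'' - lambda*u - f = 0
# up to floating point error.

xs = torch.linspace(0, 2, 50).reshape(-1, 1).requires_grad_(True)
u = torch.sin(torch.pi * xs) * torch.cos(3 * torch.pi * xs)
du = torch.autograd.grad(u, xs, torch.ones_like(u), create_graph=True)[0]
d2u = torch.autograd.grad(du, xs, torch.ones_like(du))[0]
f = -6 * torch.pi**2 * torch.sin(3 * torch.pi * xs) * torch.cos(torch.pi * xs)
print((d2u - (-10 * torch.pi**2) * u - f).abs().max())  # ~0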
# In[2]:


class Helmotz(SpatialProblem):
    output_variables = ['u']
    spatial_domain = CartesianDomain({'x': [0, 2]})

    def helmotz_equation(input_, output_):
        x = input_.extract('x')
        u_xx = laplacian(output_, input_, components=['u'], d=['x'])
        f = -6. * torch.pi**2 * torch.sin(3 * torch.pi * x) * torch.cos(torch.pi * x)
        lambda_ = -10. * torch.pi**2
        return u_xx - lambda_ * output_ - f

    # here we write the problem conditions
    conditions = {
        'D': Condition(location=spatial_domain,
                       equation=Equation(helmotz_equation)),
    }

    def helmotz_sol(self, pts):
        return torch.sin(torch.pi * pts) * torch.cos(3. * torch.pi * pts)

    truth_solution = helmotz_sol


problem = Helmotz()

# let's discretise the domain
problem.discretise_domain(200, 'grid', locations=['D'])
# As usual, the Helmholtz problem is written in **PINA** code as a class.
# The equations are written as `conditions` that should be satisfied in the
# corresponding domains. The `truth_solution`
# is the exact solution, which will be compared with the predicted one. We
# used a uniform grid for sampling the collocation points.
# ## Solving the problem with a Periodic Network

# Any $\mathcal{C}^{\infty}$ periodic function
# $u : \mathbb{R} \rightarrow \mathbb{R}$ with period
# $L\in\mathbb{N}$ can be constructed by composition of an
# arbitrary smooth function $f : \mathbb{R}^n \rightarrow \mathbb{R}$ and a
# given smooth periodic function $v : \mathbb{R} \rightarrow \mathbb{R}^n$ with
# period $L$, that is $u(x) = f(v(x))$. The formulation generalizes to
# arbitrary dimension, see [*A method for representing periodic functions and
# enforcing exactly periodic boundary conditions with
# deep neural networks*](https://arxiv.org/pdf/2007.07442).
#
# In our case, we choose
# $v(x) = \left[1, \cos\left(\frac{2\pi}{L} x\right),
# \sin\left(\frac{2\pi}{L} x\right)\right]$, i.e.
# the coordinates augmentation, and $f(\cdot) = NN_{\theta}(\cdot)$, i.e. a neural
# network. The network obtained by composing $f$ with $v$ gives
# the PINN approximate solution, that is
# $u(x) \approx u_{\theta}(x)=NN_{\theta}(v(x))$.
#
# In **PINA** this translates into using the `PeriodicBoundaryEmbedding` layer for $v$, and any
# `pina.model` for $NN_{\theta}$. Let's see it in action!
#
# In[3]:


# we encapsulate all modules in a torch.nn.Sequential container
model = torch.nn.Sequential(
    PeriodicBoundaryEmbedding(input_dimension=1, periods=2),
    FeedForward(input_dimensions=3,  # output of PeriodicBoundaryEmbedding = 3 * input_dimension
                output_dimensions=1,
                layers=[10, 10]))
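# A hedged sanity check (not part of the original tutorial): inputs one period
# apart are mapped to identical embedding features, which is what enforces
# u(0) = u(2) exactly, up to floating point error.
emb = PeriodicBoundaryEmbedding(input_dimension=1, periods=2)
x0 = torch.tensor([[0.3]])
print(torch.allclose(emb(x0), emb(x0 + 2.)))  # True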
# As simple as that! Notice that in higher dimensions you can specify a
# different period for each dimension using a dictionary, e.g. `periods={'x':2, 'y':3, ...}`
# would indicate a periodicity of $2$ in $x$, $3$ in $y$, and so on...
#
# We will now solve the problem as usual with the `PINN` and `Trainer` classes.
# In[5]:


pinn = PINN(problem=problem, model=model)
# we train on CPU and disable the model summary at the start of training (optional)
trainer = Trainer(pinn, max_epochs=5000, accelerator='cpu',
                  enable_model_summary=False)
trainer.train()
# We are going to plot the solution now!

# In[6]:


pl = Plotter()
pl.plot(pinn)
# Great, they overlap perfectly! This seems a good result, considering the
# simple neural network used to solve this (complex) problem. We will now test
# the neural network on the extended domain $[0, 4]$ without retraining. In
# principle the periodicity should be preserved, since the $v$ function ensures
# periodicity on $(-\infty, \infty)$.
# In[7]:


# plotting solution
with torch.no_grad():
    # Notice here we sample on [0, 4], beyond the training domain [0, 2]!
    new_domain = CartesianDomain({'x': [0, 4]})
    x = new_domain.sample(1000, mode='grid')
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    # Plot 1: true solution
    axes[0].plot(x, problem.truth_solution(x), label=r'$u(x)$', color='blue')
    axes[0].set_title(r'True solution $u(x)$')
    axes[0].legend(loc="upper right")
    # Plot 2: PINN solution
    axes[1].plot(x, pinn(x), label=r'$u_{\theta}(x)$', color='green')
    axes[1].set_title(r'PINN solution $u_{\theta}(x)$')
    axes[1].legend(loc="upper right")
    # Plot 3: absolute difference
    diff = torch.abs(problem.truth_solution(x) - pinn(x))
    axes[2].plot(x, diff, label=r'$|u(x) - u_{\theta}(x)|$', color='red')
    axes[2].set_title(r'Absolute difference $|u(x) - u_{\theta}(x)|$')
    axes[2].legend(loc="upper right")
    # Adjust layout and show the plots
    plt.tight_layout()
    plt.show()
# It is pretty clear that the network is periodic, with the error also
# following a periodic pattern. Obviously a longer training and a more
# expressive neural network could improve the results!
#
# ## What's next?
#
# Nice, you have completed the one dimensional Helmholtz tutorial of **PINA**!
# There are multiple directions you can go now:
#
# 1. Train the network for longer or with different layer sizes and assess the final accuracy
#
# 2. Apply the `PeriodicBoundaryEmbedding` layer to a time-dependent problem (see the reference in the documentation)
#
# 3. Exploit extra feature training?
#
# 4. Many more...