PBC Layer (#252)

* update docs/tests
* tutorial and device fix

---------

Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.local>
Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.lan>
Co-authored-by: Dario Coscia <dariocoscia@Dario-Coscia.Home>
This commit is contained in:
Dario Coscia
2024-03-01 18:15:45 +01:00
committed by GitHub
parent c92a2832d5
commit 4cfd90b904
13 changed files with 984 additions and 1 deletions

View File

@@ -68,6 +68,7 @@ Layers
Spectral convolution <layers/spectral.rst> Spectral convolution <layers/spectral.rst>
Fourier layers <layers/fourier.rst> Fourier layers <layers/fourier.rst>
Continuous convolution <layers/convolution.rst> Continuous convolution <layers/convolution.rst>
Coordinates embeddings <layers/embedding.rst>
Equations and Operators Equations and Operators

View File

@@ -21,7 +21,7 @@ Physics Informed Neural Networks
Two dimensional Poisson problem using Extra Features Learning<tutorials/tutorial2/tutorial.rst> Two dimensional Poisson problem using Extra Features Learning<tutorials/tutorial2/tutorial.rst>
Two dimensional Wave problem with hard constraint<tutorials/tutorial3/tutorial.rst> Two dimensional Wave problem with hard constraint<tutorials/tutorial3/tutorial.rst>
Resolution of a 2D Poisson inverse problem<tutorials/tutorial7/tutorial.rst> Resolution of a 2D Poisson inverse problem<tutorials/tutorial7/tutorial.rst>
Periodic Boundary Conditions for Helmotz Equation<tutorials/tutorial9/tutorial.rst>
Neural Operator Learning Neural Operator Learning
------------------------ ------------------------

View File

@@ -0,0 +1,8 @@
Coordinates embeddings
======================
.. currentmodule:: pina.model.layers.embedding
.. autoclass:: PBCEmbedding
:members:
:show-inheritance:

View File

@@ -0,0 +1,15 @@
Spectral Convolution
======================
.. currentmodule:: pina.model.layers.spectral
.. autoclass:: SpectralConvBlock1D
:members:
:show-inheritance:
.. autoclass:: SpectralConvBlock2D
:members:
:show-inheritance:
.. autoclass:: SpectralConvBlock3D
:members:
:show-inheritance:

View File

@@ -0,0 +1,226 @@
Tutorial: One dimensional Helmotz equation using Periodic Boundary Conditions
=============================================================================
This tutorial presents how to solve with Physics-Informed Neural
Networks (PINNs) a one dimensional Helmotz equation with periodic
boundary conditions (PBC). We will train with standard PINNs training
by augmenting the input with periodic expasion as presented in `An
experts guide to training physics-informed neural
networks <https://arxiv.org/abs/2308.08468>`__.
First of all, some useful imports.
.. code:: ipython3
import torch
import matplotlib.pyplot as plt
from pina import Condition, Plotter
from pina.problem import SpatialProblem
from pina.operators import laplacian
from pina.model import FeedForward
from pina.model.layers import PeriodicBoundaryEmbedding # The PBC module
from pina.solvers import PINN
from pina.trainer import Trainer
from pina.geometry import CartesianDomain
from pina.equation import Equation
The problem definition
----------------------
The one-dimensional Helmotz problem is mathematically written as:
.. math::
\begin{cases}
\frac{d^2}{dx^2}u(x) - \lambda u(x) -f(x) &= 0 \quad x\in(0,2)\\
u^{(m)}(x=0) - u^{(m)}(x=2) &= 0 \quad m\in[0, 1, \cdots]\\
\end{cases}
In this case we are asking the solution to be :math:`C^{\infty}`
periodic with period :math:`2`, on the inifite domain
:math:`x\in(-\infty, \infty)`. Notice that the classical PINN would need
inifinite conditions to evaluate the PBC loss function, one for each
derivative, which is of course infeasable… A possible solution,
diverging from the original PINN formulation, is to use *coordinates
augmentation*. In coordinates augmentation you seek for a coordinates
transformation :math:`v` such that :math:`x\rightarrow v(x)` such that
the periodicity condition $ u^{(m)}(x=0) - u^{(m)}(x=2) = 0
:raw-latex:`\quad `m:raw-latex:`\in[0, 1, \cdots] `$ is satisfied.
For demonstration porpuses the problem specifics are
:math:`\lambda=-10\pi^2`, and
:math:`f(x)=-6\pi^2\sin(3\pi x)\cos(\pi x)` which gives a solution that
can be computed analytically :math:`u(x) = \sin(\pi x)\cos(3\pi x)`.
.. code:: ipython3
class Helmotz(SpatialProblem):
output_variables = ['u']
spatial_domain = CartesianDomain({'x': [0, 2]})
def helmotz_equation(input_, output_):
x = input_.extract('x')
u_xx = laplacian(output_, input_, components=['u'], d=['x'])
f = - 6.*torch.pi**2 * torch.sin(3*torch.pi*x)*torch.cos(torch.pi*x)
lambda_ = - 10. * torch.pi ** 2
return u_xx - lambda_ * output_ - f
# here we write the problem conditions
conditions = {
'D': Condition(location=spatial_domain,
equation=Equation(helmotz_equation)),
}
def helmotz_sol(self, pts):
return torch.sin(torch.pi * pts) * torch.cos(3. * torch.pi * pts)
truth_solution = helmotz_sol
problem = Helmotz()
# let's discretise the domain
problem.discretise_domain(200, 'grid', locations=['D'])
As usual the Helmotz problem is written in **PINA** code as a class. The
equations are written as ``conditions`` that should be satisfied in the
corresponding domains. The ``truth_solution`` is the exact solution
which will be compared with the predicted one. We used latin hypercube
sampling for choosing the collocation points.
Solving the problem with a Periodic Network
-------------------------------------------
Any :math:`\mathcal{C}^{\infty}` periodic function
:math:`u : \mathbb{R} \rightarrow \mathbb{R}` with period
:math:`L\in\mathbb{N}` can be constructed by composition of an arbitrary
smooth function :math:`f : \mathbb{R}^n \rightarrow \mathbb{R}` and a
given smooth periodic function
:math:`v : \mathbb{R} \rightarrow \mathbb{R}^n` with period :math:`L`,
that is :math:`u(x) = f(v(x))`. The formulation is generalizable for
arbitrary dimension, see `A method for representing periodic functions
and enforcing exactly periodic boundary conditions with deep neural
networks <https://arxiv.org/pdf/2007.07442>`__.
In our case, we rewrite
:math:`v(x) = \left[1, \cos\left(\frac{2\pi}{L} x\right), \sin\left(\frac{2\pi}{L} x\right)\right]`,
i.e the coordinates augmentation, and
:math:`f(\cdot) = NN_{\theta}(\cdot)` i.e. a neural network. The
resulting neural network obtained by composing :math:`f` with :math:`v`
gives the PINN approximate solution, that is
:math:`u(x) \approx u_{\theta}(x)=NN_{\theta}(v(x))`.
In **PINA** this translates in using the ``PeriodicBoundaryEmbedding`` layer for
:math:`v`, and any ``pina.model`` for :math:`NN_{\theta}`. Lets see it
in action!
.. code:: ipython3
# we encapsulate all modules in a torch.nn.Sequential container
model = torch.nn.Sequential(PeriodicBoundaryEmbedding(input_dimension=1,
periods=2),
FeedForward(input_dimensions=3, # output of PeriodicBoundaryEmbedding = 3 * input_dimension
output_dimensions=1,
layers=[10, 10]))
As simple as that! Notice in higher dimension you can specify different
periods for all dimensions using a dictionary,
e.g. ``periods={'x':2, 'y':3, ...}`` would indicate a periodicity of
:math:`2` in :math:`x`, :math:`3` in :math:`y`, and so on…
We will now sole the problem as usually with the ``PINN`` and
``Trainer`` class.
.. code:: ipython3
pinn = PINN(problem=problem, model=model)
trainer = Trainer(pinn, max_epochs=5000, accelerator='cpu', enable_model_summary=False) # we train on CPU and avoid model summary at beginning of training (optional)
trainer.train()
.. parsed-literal::
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
.. parsed-literal::
`Trainer.fit` stopped: `max_epochs=5000` reached.
.. parsed-literal::
Epoch 4999: 100%|██████████| 1/1 [00:00<00:00, 155.47it/s, v_num=20, D_loss=0.0123, mean_loss=0.0123]
We are going to plot the solution now!
.. code:: ipython3
pl = Plotter()
pl.plot(pinn)
.. image:: tutorial_files/tutorial_11_0.png
Great, they overlap perfectly! This seeams a good result, considering
the simple neural network used to some this (complex) problem. We will
now test the neural network on the domain :math:`[-4, 4]` without
retraining. In principle the periodicity should be present since the
:math:`v` function ensures the periodicity in :math:`(-\infty, \infty)`.
.. code:: ipython3
# plotting solution
with torch.no_grad():
# Notice here we put [-4, 4]!!!
new_domain = CartesianDomain({'x' : [0, 4]})
x = new_domain.sample(1000, mode='grid')
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# Plot 1
axes[0].plot(x, problem.truth_solution(x), label=r'$u(x)$', color='blue')
axes[0].set_title(r'True solution $u(x)$')
axes[0].legend(loc="upper right")
# Plot 2
axes[1].plot(x, pinn(x), label=r'$u_{\theta}(x)$', color='green')
axes[1].set_title(r'PINN solution $u_{\theta}(x)$')
axes[1].legend(loc="upper right")
# Plot 3
diff = torch.abs(problem.truth_solution(x) - pinn(x))
axes[2].plot(x, diff, label=r'$|u(x) - u_{\theta}(x)|$', color='red')
axes[2].set_title(r'Absolute difference $|u(x) - u_{\theta}(x)|$')
axes[2].legend(loc="upper right")
# Adjust layout
plt.tight_layout()
# Show the plots
plt.show()
.. image:: tutorial_files/tutorial_13_0.png
It is pretty clear that the network is periodic, with also the error
following a periodic pattern. Obviusly a longer training, and a more
expressive neural network could improve the results!
Whats next?
------------
Nice you have completed the one dimensional Helmotz tutorial of
**PINA**! There are multiple directions you can go now:
1. Train the network for longer or with different layer sizes and assert
the finaly accuracy
2. Apply the ``PeriodicBoundaryEmbedding`` layer for a time-dependent problem (see
reference in the documentation)
3. Exploit extrafeature training ?
4. Many more…

Binary file not shown.

After

Width:  |  Height:  |  Size: 59 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 90 KiB

View File

@@ -9,6 +9,7 @@ __all__ = [
"FourierBlock2D", "FourierBlock2D",
"FourierBlock3D", "FourierBlock3D",
"PODBlock", "PODBlock",
"PeriodicBoundaryEmbedding"
] ]
from .convolution_2d import ContinuousConvBlock from .convolution_2d import ContinuousConvBlock
@@ -20,3 +21,4 @@ from .spectral import (
) )
from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D from .fourier import FourierBlock1D, FourierBlock2D, FourierBlock3D
from .pod import PODBlock from .pod import PODBlock
from .embedding import PeriodicBoundaryEmbedding

View File

@@ -0,0 +1,142 @@
""" Periodic Boundary Embedding modulus. """
import torch
from pina.utils import check_consistency
class PeriodicBoundaryEmbedding(torch.nn.Module):
r"""
Imposing hard constraint periodic boundary conditions by embedding the
input.
A periodic function :math:`u:\mathbb{R}^{\rm{in}}
\rightarrow\mathbb{R}^{\rm{out}}` periodic in the spatial
coordinates :math:`\mathbf{x}` with periods :math:`\mathbf{L}` is such that:
.. math::
u(\mathbf{x}) = u(\mathbf{x} + n \mathbf{L})\;\;
\forall n\in\mathbb{N}.
The :meth:`PeriodicBoundaryEmbedding` augments the input such that the periodic conditons
is guarantee. The input is augmented by the following formula:
.. math::
\mathbf{x} \rightarrow \tilde{\mathbf{x}} = \left[1,
\cos\left(\frac{2\pi}{L_1} x_1 \right),
\sin\left(\frac{2\pi}{L_1}x_1\right), \cdots,
\cos\left(\frac{2\pi}{L_{\rm{in}}}x_{\rm{in}}\right),
\sin\left(\frac{2\pi}{L_{\rm{in}}}x_{\rm{in}}\right)\right],
where :math:`\text{dim}(\tilde{\mathbf{x}}) = 3\text{dim}(\mathbf{x})`.
.. seealso::
**Original reference**:
1. Dong, Suchuan, and Naxian Ni (2021). *A method for representing
periodic functions and enforcing exactly periodic boundary
conditions with deep neural networks*. Journal of Computational
Physics 435, 110242.
DOI: `10.1016/j.jcp.2021.110242.
<https://doi.org/10.1016/j.jcp.2021.110242>`_
2. Wang, S., Sankaran, S., Wang, H., & Perdikaris, P. (2023). *An
expert's guide to training physics-informed neural networks*.
DOI: `arXiv preprint arXiv:2308.0846.
<https://arxiv.org/abs/2308.08468>`_
.. warning::
The embedding is a truncated fourier expansion, and only ensures
function PBC and not for its derivatives. Ensuring approximate
periodicity in
the derivatives of :math:`u` can be done, and extensive
tests have shown (also in the reference papers) that this implementation
can correctly compute the PBC on the derivatives up to the order
:math:`\sim 2,3`, while it is not guarantee the periodicity for
:math:`>3`. The PINA code is tested only for function PBC and not for
its derivatives.
"""
def __init__(self, input_dimension, periods, output_dimension=None):
"""
:param int input_dimension: The dimension of the input tensor, it can
be checked with `tensor.ndim` method.
:param float | int | dict periods: The periodicity in each dimension for
the input data. If ``float`` or ``int`` is passed,
the period is assumed constant for all the dimensions of the data.
If a ``dict`` is passed the `dict.values` represent periods,
while the ``dict.keys`` represent the dimension where the
periodicity is applied. The `dict.keys` can either be `int`
if working with ``torch.Tensor`` or ``str`` if
working with ``LabelTensor``.
:param int output_dimension: The dimension of the output after the
fourier embedding. If not ``None`` a ``torch.nn.Linear`` layer
is applied to the fourier embedding output to match the desired
dimensionality, default ``None``.
"""
super().__init__()
# check input consistency
check_consistency(periods, (float, int, dict))
check_consistency(input_dimension, int)
if output_dimension is not None:
check_consistency(output_dimension, int)
self._layer = torch.nn.Linear(input_dimension * 3, output_dimension)
else:
self._layer = torch.nn.Identity()
# checks on the periods
if isinstance(periods, dict):
if not all(isinstance(dim, (str, int)) and
isinstance(period, (float, int))
for dim, period in periods.items()):
raise TypeError('In dictionary periods, keys must be integers'
' or strings, and values must be float or int.')
self._period = periods
else:
self._period = {k: periods for k in range(input_dimension)}
def forward(self, x):
"""
Forward pass to compute the periodic boundary conditions embedding.
:param torch.Tensor x: Input tensor.
:return: Fourier embeddings of the input.
:rtype: torch.Tensor
"""
omega = torch.stack([torch.pi * 2. / torch.tensor([val],
device=x.device)
for val in self._period.values()],
dim=-1)
x = self._get_vars(x, list(self._period.keys()))
return self._layer(torch.cat([torch.ones_like(x),
torch.cos(omega * x),
torch.sin(omega * x)], dim=-1))
def _get_vars(self, x, indeces):
"""
Get variables from input tensor ordered by specific indeces.
:param torch.Tensor x: The input tensor to extract.
:param list[int] | list[str] indeces: List of indeces to extract.
:return: The extracted tensor given the indeces.
:rtype: torch.Tensor
"""
if isinstance(indeces[0], str):
try:
return x.extract(indeces)
except AttributeError:
raise RuntimeError(
'Not possible to extract input variables from tensor.'
' Ensure that the passed tensor is a LabelTensor or'
' pass list of integers to extract variables. For'
' more information refer to warning in the documentation.')
elif isinstance(indeces[0], int):
return x[..., indeces]
else:
raise RuntimeError(
'Not able to extract right indeces for tensor.'
' For more information refer to warning in the documentation.')
@property
def period(self):
"""
The period of the periodic function to approximate.
"""
return self._period

View File

@@ -0,0 +1,99 @@
import torch
import pytest
from pina.model.layers import PeriodicBoundaryEmbedding
from pina import LabelTensor
def check_same_columns(tensor):
# Get the first column
first_column = tensor[0]
# Compare each column with the first column
all_same = torch.allclose(tensor, first_column)
return all_same
def grad(u, x):
"""
Compute the first derivative of u with respect to x.
"""
return torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u),
create_graph=True, allow_unused=True,
retain_graph=True)[0]
def test_constructor():
PeriodicBoundaryEmbedding(input_dimension=1, periods=2)
PeriodicBoundaryEmbedding(input_dimension=1, periods={'x': 3, 'y' : 4})
PeriodicBoundaryEmbedding(input_dimension=1, periods={0: 3, 1 : 4})
PeriodicBoundaryEmbedding(input_dimension=1, periods=2, output_dimension=10)
with pytest.raises(TypeError):
PeriodicBoundaryEmbedding()
with pytest.raises(ValueError):
PeriodicBoundaryEmbedding(input_dimension=1., periods=1)
PeriodicBoundaryEmbedding(input_dimension=1, periods=1, output_dimension=1.)
PeriodicBoundaryEmbedding(input_dimension=1, periods={'x':'x'})
PeriodicBoundaryEmbedding(input_dimension=1, periods={0:'x'})
@pytest.mark.parametrize("period", [1, 4, 10])
@pytest.mark.parametrize("input_dimension", [1, 2, 3])
def test_forward_same_period(input_dimension, period):
func = torch.nn.Sequential(
PeriodicBoundaryEmbedding(input_dimension=input_dimension,
output_dimension=60, periods=period),
torch.nn.Tanh(),
torch.nn.Linear(60, 60),
torch.nn.Tanh(),
torch.nn.Linear(60, 1)
)
# coordinates
x = period * torch.tensor([[0.],[1.]])
if input_dimension == 2:
x = torch.cartesian_prod(x.flatten(),x.flatten())
elif input_dimension == 3:
x = torch.cartesian_prod(x.flatten(),x.flatten(),x.flatten())
x.requires_grad = True
# output
f = func(x)
assert check_same_columns(f)
def test_forward_same_period_labels():
func = torch.nn.Sequential(
PeriodicBoundaryEmbedding(input_dimension=2,
output_dimension=60, periods={'x':1, 'y':2}),
torch.nn.Tanh(),
torch.nn.Linear(60, 60),
torch.nn.Tanh(),
torch.nn.Linear(60, 1)
)
# coordinates
tensor = torch.tensor([[0., 0.], [0., 2.], [1., 0.], [1., 2.]])
with pytest.raises(RuntimeError):
func(tensor)
tensor = tensor.as_subclass(LabelTensor)
tensor.labels = ['x', 'y']
tensor.requires_grad = True
# output
f = func(tensor)
assert check_same_columns(f)
def test_forward_same_period_index():
func = torch.nn.Sequential(
PeriodicBoundaryEmbedding(input_dimension=2,
output_dimension=60, periods={0:1, 1:2}),
torch.nn.Tanh(),
torch.nn.Linear(60, 60),
torch.nn.Tanh(),
torch.nn.Linear(60, 1)
)
# coordinates
tensor = torch.tensor([[0., 0.], [0., 2.], [1., 0.], [1., 2.]])
tensor.requires_grad = True
# output
f = func(tensor)
assert check_same_columns(f)
tensor = tensor.as_subclass(LabelTensor)
tensor.labels = ['x', 'y']
# output
f = func(tensor)
assert check_same_columns(f)

1
tutorials/README.md vendored
View File

@@ -16,6 +16,7 @@ Building custom geometries with PINA `Location` class|[[.ipynb](tutorial6/tutori
Two dimensional Poisson problem using Extra Features Learning &nbsp; &nbsp; |[[.ipynb](tutorial2/tutorial.ipynb),&#160;[.py](tutorial2/tutorial.py),&#160;[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial2/tutorial.html)]| Two dimensional Poisson problem using Extra Features Learning &nbsp; &nbsp; |[[.ipynb](tutorial2/tutorial.ipynb),&#160;[.py](tutorial2/tutorial.py),&#160;[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial2/tutorial.html)]|
Two dimensional Wave problem with hard constraint |[[.ipynb](tutorial3/tutorial.ipynb),&#160;[.py](tutorial3/tutorial.py),&#160;[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial3/tutorial.html)]| Two dimensional Wave problem with hard constraint |[[.ipynb](tutorial3/tutorial.ipynb),&#160;[.py](tutorial3/tutorial.py),&#160;[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial3/tutorial.html)]|
Resolution of a 2D Poisson inverse problem |[[.ipynb](tutorial7/tutorial.ipynb),&#160;[.py](tutorial7/tutorial.py),&#160;[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial7/tutorial.html)]| Resolution of a 2D Poisson inverse problem |[[.ipynb](tutorial7/tutorial.ipynb),&#160;[.py](tutorial7/tutorial.py),&#160;[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial7/tutorial.html)]|
Periodic Boundary Conditions for Helmotz Equation |[[.ipynb](tutorial9/tutorial.ipynb),&#160;[.py](tutorial9/tutorial.py),&#160;[.html](http://mathlab.github.io/PINA/_rst/tutorials/tutorial9/tutorial.html)]|
## Neural Operator Learning ## Neural Operator Learning
| Description | Tutorial | | Description | Tutorial |

298
tutorials/tutorial9/tutorial.ipynb vendored Normal file

File diff suppressed because one or more lines are too long

191
tutorials/tutorial9/tutorial.py vendored Normal file
View File

@@ -0,0 +1,191 @@
#!/usr/bin/env python
# coding: utf-8
# # Tutorial: One dimensional Helmotz equation using Periodic Boundary Conditions
# This tutorial presents how to solve with Physics-Informed Neural Networks (PINNs)
# a one dimensional Helmotz equation with periodic boundary conditions (PBC).
# We will train with standard PINN's training by augmenting the input with
# periodic expasion as presented in [*An experts guide to training
# physics-informed neural networks*](
# https://arxiv.org/abs/2308.08468).
#
# First of all, some useful imports.
# In[1]:
import torch
import matplotlib.pyplot as plt
from pina import Condition, Plotter
from pina.problem import SpatialProblem
from pina.operators import laplacian
from pina.model import FeedForward
from pina.model.layers import PeriodicBoundaryEmbedding # The PBC module
from pina.solvers import PINN
from pina.trainer import Trainer
from pina.geometry import CartesianDomain
from pina.equation import Equation
# ## The problem definition
#
# The one-dimensional Helmotz problem is mathematically written as:
# $$
# \begin{cases}
# \frac{d^2}{dx^2}u(x) - \lambda u(x) -f(x) &= 0 \quad x\in(0,2)\\
# u^{(m)}(x=0) - u^{(m)}(x=2) &= 0 \quad m\in[0, 1, \cdots]\\
# \end{cases}
# $$
# In this case we are asking the solution to be $C^{\infty}$ periodic with
# period $2$, on the inifite domain $x\in(-\infty, \infty)$. Notice that the
# classical PINN would need inifinite conditions to evaluate the PBC loss function,
# one for each derivative, which is of course infeasable...
# A possible solution, diverging from the original PINN formulation,
# is to use *coordinates augmentation*. In coordinates augmentation you seek for
# a coordinates transformation $v$ such that $x\rightarrow v(x)$ such that
# the periodicity condition $ u^{(m)}(x=0) - u^{(m)}(x=2) = 0 \quad m\in[0, 1, \cdots] $ is
# satisfied.
#
# For demonstration porpuses the problem specifics are $\lambda=-10\pi^2$,
# and $f(x)=-6\pi^2\sin(3\pi x)\cos(\pi x)$ which gives a solution that can be
# computed analytically $u(x) = \sin(\pi x)\cos(3\pi x)$.
# In[2]:
class Helmotz(SpatialProblem):
output_variables = ['u']
spatial_domain = CartesianDomain({'x': [0, 2]})
def helmotz_equation(input_, output_):
x = input_.extract('x')
u_xx = laplacian(output_, input_, components=['u'], d=['x'])
f = - 6.*torch.pi**2 * torch.sin(3*torch.pi*x)*torch.cos(torch.pi*x)
lambda_ = - 10. * torch.pi ** 2
return u_xx - lambda_ * output_ - f
# here we write the problem conditions
conditions = {
'D': Condition(location=spatial_domain,
equation=Equation(helmotz_equation)),
}
def helmotz_sol(self, pts):
return torch.sin(torch.pi * pts) * torch.cos(3. * torch.pi * pts)
truth_solution = helmotz_sol
problem = Helmotz()
# let's discretise the domain
problem.discretise_domain(200, 'grid', locations=['D'])
# As usual the Helmotz problem is written in **PINA** code as a class.
# The equations are written as `conditions` that should be satisfied in the
# corresponding domains. The `truth_solution`
# is the exact solution which will be compared with the predicted one. We used
# latin hypercube sampling for choosing the collocation points.
# ## Solving the problem with a Periodic Network
# Any $\mathcal{C}^{\infty}$ periodic function
# $u : \mathbb{R} \rightarrow \mathbb{R}$ with period
# $L\in\mathbb{N}$ can be constructed by composition of an
# arbitrary smooth function $f : \mathbb{R}^n \rightarrow \mathbb{R}$ and a
# given smooth periodic function $v : \mathbb{R} \rightarrow \mathbb{R}^n$ with
# period $L$, that is $u(x) = f(v(x))$. The formulation is generalizable for
# arbitrary dimension, see [*A method for representing periodic functions and
# enforcing exactly periodic boundary conditions with
# deep neural networks*](https://arxiv.org/pdf/2007.07442).
#
# In our case, we rewrite
# $v(x) = \left[1, \cos\left(\frac{2\pi}{L} x\right),
# \sin\left(\frac{2\pi}{L} x\right)\right]$, i.e
# the coordinates augmentation, and $f(\cdot) = NN_{\theta}(\cdot)$ i.e. a neural
# network. The resulting neural network obtained by composing $f$ with $v$ gives
# the PINN approximate solution, that is
# $u(x) \approx u_{\theta}(x)=NN_{\theta}(v(x))$.
#
# In **PINA** this translates in using the `PeriodicBoundaryEmbedding` layer for $v$, and any
# `pina.model` for $NN_{\theta}$. Let's see it in action!
#
# In[3]:
# we encapsulate all modules in a torch.nn.Sequential container
model = torch.nn.Sequential(PeriodicBoundaryEmbedding(input_dimension=1,
periods=2),
FeedForward(input_dimensions=3, # output of PeriodicBoundaryEmbedding = 3 * input_dimension
output_dimensions=1,
layers=[10, 10]))
# As simple as that! Notice in higher dimension you can specify different periods
# for all dimensions using a dictionary, e.g. `periods={'x':2, 'y':3, ...}`
# would indicate a periodicity of $2$ in $x$, $3$ in $y$, and so on...
#
# We will now sole the problem as usually with the `PINN` and `Trainer` class.
# In[5]:
pinn = PINN(problem=problem, model=model)
trainer = Trainer(pinn, max_epochs=5000, accelerator='cpu', enable_model_summary=False) # we train on CPU and avoid model summary at beginning of training (optional)
trainer.train()
# We are going to plot the solution now!
# In[6]:
pl = Plotter()
pl.plot(pinn)
# Great, they overlap perfectly! This seeams a good result, considering the simple neural network used to some this (complex) problem. We will now test the neural network on the domain $[-4, 4]$ without retraining. In principle the periodicity should be present since the $v$ function ensures the periodicity in $(-\infty, \infty)$.
# In[7]:
# plotting solution
with torch.no_grad():
# Notice here we put [-4, 4]!!!
new_domain = CartesianDomain({'x' : [0, 4]})
x = new_domain.sample(1000, mode='grid')
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
# Plot 1
axes[0].plot(x, problem.truth_solution(x), label=r'$u(x)$', color='blue')
axes[0].set_title(r'True solution $u(x)$')
axes[0].legend(loc="upper right")
# Plot 2
axes[1].plot(x, pinn(x), label=r'$u_{\theta}(x)$', color='green')
axes[1].set_title(r'PINN solution $u_{\theta}(x)$')
axes[1].legend(loc="upper right")
# Plot 3
diff = torch.abs(problem.truth_solution(x) - pinn(x))
axes[2].plot(x, diff, label=r'$|u(x) - u_{\theta}(x)|$', color='red')
axes[2].set_title(r'Absolute difference $|u(x) - u_{\theta}(x)|$')
axes[2].legend(loc="upper right")
# Adjust layout
plt.tight_layout()
# Show the plots
plt.show()
# It is pretty clear that the network is periodic, with also the error following a periodic pattern. Obviusly a longer training, and a more expressive neural network could improve the results!
#
# ## What's next?
#
# Nice you have completed the one dimensional Helmotz tutorial of **PINA**! There are multiple directions you can go now:
#
# 1. Train the network for longer or with different layer sizes and assert the finaly accuracy
#
# 2. Apply the `PeriodicBoundaryEmbedding` layer for a time-dependent problem (see reference in the documentation)
#
# 3. Exploit extrafeature training ?
#
# 4. Many more...