Inverse problem implementation (#177)
* inverse problem implementation * add tutorial7 for inverse Poisson problem * fix doc in equation, equation_interface, system_equation --------- Co-authored-by: Dario Coscia <dariocoscia@dhcp-015.eduroam.sissa.it>
This commit is contained in:
committed by
Nicola Demo
parent
a9f14ac323
commit
0b7a307cf1
@@ -1,7 +1,7 @@
|
||||
PINA Tutorials
|
||||
==============
|
||||
|
||||
In this folder we collect useful tutorials in order to understand the principles and the potential of **PINA**.
|
||||
In this folder we collect useful tutorials in order to understand the principles and the potential of **PINA**.
|
||||
|
||||
Getting started with PINA
|
||||
-------------------------
|
||||
@@ -20,6 +20,7 @@ Physics Informed Neural Networks
|
||||
|
||||
Two dimensional Poisson problem using Extra Features Learning<tutorials/tutorial2/tutorial.rst>
|
||||
Two dimensional Wave problem with hard constraint<tutorials/tutorial3/tutorial.rst>
|
||||
Resolution of a 2D Poisson inverse problem<tutorials/tutorial7/tutorial.rst>
|
||||
|
||||
|
||||
Neural Operator Learning
|
||||
@@ -36,4 +37,5 @@ Supervised Learning
|
||||
:maxdepth: 3
|
||||
:titlesonly:
|
||||
|
||||
Unstructured convolutional autoencoder via continuous convolution<tutorials/tutorial4/tutorial.rst>
|
||||
Unstructured convolutional autoencoder via continuous convolution<tutorials/tutorial4/tutorial.rst>
|
||||
|
||||
|
||||
217
docs/source/_rst/tutorials/tutorial7/tutorial.rst
Normal file
217
docs/source/_rst/tutorials/tutorial7/tutorial.rst
Normal file
@@ -0,0 +1,217 @@
|
||||
Tutorial 7: Resolution of an inverse problem
|
||||
============================================
|
||||
|
||||
Introduction to the inverse problem
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
This tutorial shows how to solve an inverse Poisson problem with
|
||||
Physics-Informed Neural Networks. The problem definition is that of a
|
||||
Poisson problem with homogeneous boundary conditions and it reads:
|
||||
:raw-latex:`\begin{equation}
|
||||
\begin{cases}
|
||||
\Delta u = e^{-2(x-\mu_1)^2-2(y-\mu_2)^2} \text{ in } \Omega\, ,\\
|
||||
u = 0 \text{ on }\partial \Omega,\\
|
||||
u(\mu_1, \mu_2) = \text{ data}
|
||||
\end{cases}
|
||||
\end{equation}` where :math:`\Omega` is a square domain
|
||||
:math:`[-2, 2] \times [-2, 2]`, and
|
||||
:math:`\partial \Omega=\Gamma_1 \cup \Gamma_2 \cup \Gamma_3 \cup \Gamma_4`
|
||||
is the union of the boundaries of the domain.
|
||||
|
||||
This kind of problem, namely the “inverse problem”, has two main goals:
|
||||
- find the solution :math:`u` that satisfies the Poisson equation; -
|
||||
find the unknown parameters (:math:`\mu_1`, :math:`\mu_2`) that better
|
||||
fit some given data (third equation in the system above).
|
||||
|
||||
In order to achieve both the goals we will need to define an
|
||||
``InverseProblem`` in PINA.
|
||||
|
||||
Let’s start with useful imports.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pina.problem import SpatialProblem, InverseProblem
|
||||
from pina.operators import laplacian
|
||||
from pina.model import FeedForward
|
||||
from pina.equation import Equation, FixedValue
|
||||
from pina import Condition, Trainer
|
||||
from pina.solvers import PINN
|
||||
from pina.geometry import CartesianDomain
|
||||
|
||||
Then, we import the pre-saved data, for (:math:`\mu_1`,
|
||||
:math:`\mu_2`)=(:math:`0.5`, :math:`0.5`). These two values are the
|
||||
optimal parameters that we want to find through the neural network
|
||||
training. In particular, we import the ``input_points``\ (the spatial
|
||||
coordinates), and the ``output_points`` (the corresponding :math:`u`
|
||||
values evaluated at the ``input_points``).
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
data_output = torch.load('data/pinn_solution_0.5_0.5').detach()
|
||||
data_input = torch.load('data/pts_0.5_0.5')
|
||||
|
||||
Moreover, let’s plot also the data points and the reference solution:
|
||||
this is the expected output of the neural network.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
points = data_input.extract(['x', 'y']).detach().numpy()
|
||||
truth = data_output.detach().numpy()
|
||||
|
||||
plt.scatter(points[:, 0], points[:, 1], c=truth, s=8)
|
||||
plt.axis('equal')
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
.. image:: tutorial_files/output_8_0.png
|
||||
|
||||
|
||||
Inverse problem definition in PINA
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Then, we initialize the Poisson problem, that is inherited from the
|
||||
``SpatialProblem`` and from the ``InverseProblem`` classes. We here have
|
||||
to define all the variables, and the domain where our unknown parameters
|
||||
(:math:`\mu_1`, :math:`\mu_2`) belong. Notice that the laplace equation
|
||||
takes as inputs also the unknown variables, that will be treated as
|
||||
parameters that the neural network optimizes during the training
|
||||
process.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
### Define ranges of variables
|
||||
x_min = -2
|
||||
x_max = 2
|
||||
y_min = -2
|
||||
y_max = 2
|
||||
|
||||
class Poisson(SpatialProblem, InverseProblem):
|
||||
'''
|
||||
Problem definition for the Poisson equation.
|
||||
'''
|
||||
output_variables = ['u']
|
||||
spatial_domain = CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]})
|
||||
# define the ranges for the parameters
|
||||
unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1], 'mu2': [-1, 1]})
|
||||
|
||||
def laplace_equation(input_, output_, params_):
|
||||
'''
|
||||
Laplace equation with a force term.
|
||||
'''
|
||||
force_term = torch.exp(
|
||||
- 2*(input_.extract(['x']) - params_['mu1'])**2
|
||||
- 2*(input_.extract(['y']) - params_['mu2'])**2)
|
||||
delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y'])
|
||||
|
||||
return delta_u - force_term
|
||||
|
||||
# define the conditions for the loss (boundary conditions, equation, data)
|
||||
conditions = {
|
||||
'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max],
|
||||
'y': y_max}),
|
||||
equation=FixedValue(0.0, components=['u'])),
|
||||
'gamma2': Condition(location=CartesianDomain({'x': [x_min, x_max], 'y': y_min
|
||||
}),
|
||||
equation=FixedValue(0.0, components=['u'])),
|
||||
'gamma3': Condition(location=CartesianDomain({'x': x_max, 'y': [y_min, y_max]
|
||||
}),
|
||||
equation=FixedValue(0.0, components=['u'])),
|
||||
'gamma4': Condition(location=CartesianDomain({'x': x_min, 'y': [y_min, y_max]
|
||||
}),
|
||||
equation=FixedValue(0.0, components=['u'])),
|
||||
'D': Condition(location=CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]
|
||||
}),
|
||||
equation=Equation(laplace_equation)),
|
||||
'data': Condition(input_points=data_input.extract(['x', 'y']), output_points=data_output)
|
||||
}
|
||||
|
||||
problem = Poisson()
|
||||
|
||||
Then, we define the model of the neural network we want to use. Here we
|
||||
used a model which impose hard constrains on the boundary conditions, as
|
||||
also done in the Wave tutorial!
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
model = FeedForward(
|
||||
layers=[20, 20, 20],
|
||||
func=torch.nn.Softplus,
|
||||
output_dimensions=len(problem.output_variables),
|
||||
input_dimensions=len(problem.input_variables)
|
||||
)
|
||||
|
||||
After that, we discretize the spatial domain.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
problem.discretise_domain(20, 'grid', locations=['D'], variables=['x', 'y'])
|
||||
problem.discretise_domain(1000, 'random', locations=['gamma1', 'gamma2',
|
||||
'gamma3', 'gamma4'], variables=['x', 'y'])
|
||||
|
||||
Here, we define a simple callback for the trainer. We use this callback
|
||||
to save the parameters predicted by the neural network during the
|
||||
training. The parameters are saved every 100 epochs as ``torch`` tensors
|
||||
in a specified directory (``tmp_dir`` in our case). The goal is to read
|
||||
the saved parameters after training and plot their trend across the
|
||||
epochs.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
# temporary directory for saving logs of training
|
||||
tmp_dir = "tmp_poisson_inverse"
|
||||
|
||||
class SaveParameters(Callback):
|
||||
'''
|
||||
Callback to save the parameters of the model every 100 epochs.
|
||||
'''
|
||||
def on_train_epoch_end(self, trainer, __):
|
||||
if trainer.current_epoch % 100 == 99:
|
||||
torch.save(trainer.solver.problem.unknown_parameters, '{}/parameters_epoch{}'.format(tmp_dir, trainer.current_epoch))
|
||||
|
||||
Then, we define the ``PINN`` object and train the solver using the
|
||||
``Trainer``.
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
### train the problem with PINN
|
||||
max_epochs = 5000
|
||||
pinn = PINN(problem, model, optimizer_kwargs={'lr':0.005})
|
||||
# define the trainer for the solver
|
||||
trainer = Trainer(solver=pinn, accelerator='cpu', max_epochs=max_epochs,
|
||||
default_root_dir=tmp_dir, callbacks=[SaveParameters()])
|
||||
trainer.train()
|
||||
|
||||
One can now see how the parameters vary during the training by reading
|
||||
the saved solution and plotting them. The plot shows that the parameters
|
||||
stabilize to their true value before reaching the epoch :math:`1000`!
|
||||
|
||||
.. code:: ipython3
|
||||
|
||||
epochs_saved = range(99, max_epochs, 100)
|
||||
parameters = torch.empty((int(max_epochs/100), 2))
|
||||
for i, epoch in enumerate(epochs_saved):
|
||||
params_torch = torch.load('{}/parameters_epoch{}'.format(tmp_dir, epoch))
|
||||
for e, var in enumerate(pinn.problem.unknown_variables):
|
||||
parameters[i, e] = params_torch[var].data
|
||||
|
||||
# Plot parameters
|
||||
plt.close()
|
||||
plt.plot(epochs_saved, parameters[:, 0], label='mu1', marker='o')
|
||||
plt.plot(epochs_saved, parameters[:, 1], label='mu2', marker='s')
|
||||
plt.ylim(-1, 1)
|
||||
plt.grid()
|
||||
plt.legend()
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('Parameter value')
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
.. image:: tutorial_files/output_21_0.png
|
||||
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 10 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 112 KiB |
@@ -37,7 +37,7 @@ class Condition:
|
||||
>>> example_input_pts = LabelTensor(
|
||||
>>> torch.tensor([[0, 0, 0]]), ['x', 'y', 'z'])
|
||||
>>> example_output_pts = LabelTensor(torch.tensor([[1, 2]]), ['a', 'b'])
|
||||
>>>
|
||||
>>>
|
||||
>>> Condition(
|
||||
>>> input_points=example_input_pts,
|
||||
>>> output_points=example_output_pts)
|
||||
|
||||
@@ -9,4 +9,4 @@ __all__ = [
|
||||
|
||||
from .equation import Equation
|
||||
from .equation_factory import FixedFlux, FixedGradient, Laplace, FixedValue
|
||||
from .system_equation import SystemEquation
|
||||
from .system_equation import SystemEquation
|
||||
|
||||
@@ -8,7 +8,7 @@ class Equation(EquationInterface):
|
||||
"""
|
||||
Equation class for specifing any equation in PINA.
|
||||
Each ``equation`` passed to a ``Condition`` object
|
||||
must be an ``Equation`` or ``SystemEquation``.
|
||||
must be an ``Equation`` or ``SystemEquation``.
|
||||
|
||||
:param equation: A ``torch`` callable equation to
|
||||
evaluate the residual.
|
||||
@@ -20,14 +20,26 @@ class Equation(EquationInterface):
|
||||
f'{equation}')
|
||||
self.__equation = equation
|
||||
|
||||
def residual(self, input_, output_):
|
||||
def residual(self, input_, output_, params_ = None):
|
||||
"""
|
||||
Residual computation of the equation.
|
||||
|
||||
:param LabelTensor input_: Input points to evaluate the equation.
|
||||
:param LabelTensor output_: Output vectors given my a model (e.g,
|
||||
:param LabelTensor output_: Output vectors given by a model (e.g,
|
||||
a ``FeedForward`` model).
|
||||
:param dict params_: Dictionary of parameters related to the inverse
|
||||
problem (if any).
|
||||
If the equation is not related to an ``InverseProblem``, the
|
||||
parameters are initialized to ``None`` and the residual is
|
||||
computed as ``equation(input_, output_)``.
|
||||
Otherwise, the parameters are automatically initialized in the
|
||||
ranges specified by the user.
|
||||
|
||||
:return: The residual evaluation of the specified equation.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self.__equation(input_, output_)
|
||||
if params_ is None:
|
||||
result = self.__equation(input_, output_)
|
||||
else:
|
||||
result = self.__equation(input_, output_, params_)
|
||||
return result
|
||||
|
||||
@@ -11,3 +11,17 @@ class EquationInterface(metaclass=ABCMeta):
|
||||
the output variables, the condition(s), and the domain(s) where the
|
||||
conditions are applied.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def residual(self, input_, output_, params_):
|
||||
"""
|
||||
Residual computation of the equation.
|
||||
|
||||
:param LabelTensor input_: Input points to evaluate the equation.
|
||||
:param LabelTensor output_: Output vectors given by my model (e.g., a ``FeedForward`` model).
|
||||
:param dict params_: Dictionary of unknown parameters, eventually
|
||||
related to an ``InverseProblem``.
|
||||
:return: The residual evaluation of the specified equation.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -11,14 +11,14 @@ class SystemEquation(Equation):
|
||||
System of Equation class for specifing any system
|
||||
of equations in PINA.
|
||||
Each ``equation`` passed to a ``Condition`` object
|
||||
must be an ``Equation`` or ``SystemEquation``.
|
||||
A ``SystemEquation`` is specified by a list of
|
||||
must be an ``Equation`` or ``SystemEquation``.
|
||||
A ``SystemEquation`` is specified by a list of
|
||||
equations.
|
||||
|
||||
:param Callable equation: A ``torch`` callable equation to
|
||||
evaluate the residual
|
||||
:param str reduction: Specifies the reduction to apply to the output:
|
||||
``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction
|
||||
``none`` | ``mean`` | ``sum`` | ``callable``. ``none``: no reduction
|
||||
will be applied, ``mean``: the sum of the output will be divided
|
||||
by the number of elements in the output, ``sum``: the output will
|
||||
be summed. ``callable`` a callable function to perform reduction,
|
||||
@@ -43,19 +43,28 @@ class SystemEquation(Equation):
|
||||
raise NotImplementedError(
|
||||
'Only mean and sum reductions implemented.')
|
||||
|
||||
def residual(self, input_, output_):
|
||||
def residual(self, input_, output_, params_=None):
|
||||
"""
|
||||
Residual computation of the equation.
|
||||
Residual computation for the equations of the system.
|
||||
|
||||
:param LabelTensor input_: Input points to evaluate the equation.
|
||||
:param LabelTensor output_: Output vectors given my a model (e.g,
|
||||
:param LabelTensor input_: Input points to evaluate the system of
|
||||
equations.
|
||||
:param LabelTensor output_: Output vectors given by a model (e.g,
|
||||
a ``FeedForward`` model).
|
||||
:return: The residual evaluation of the specified equation,
|
||||
:param dict params_: Dictionary of parameters related to the inverse
|
||||
problem (if any).
|
||||
If the equation is not related to an ``InverseProblem``, the
|
||||
parameters are initialized to ``None`` and the residual is
|
||||
computed as ``equation(input_, output_)``.
|
||||
Otherwise, the parameters are automatically initialized in the
|
||||
ranges specified by the user.
|
||||
|
||||
:return: The residual evaluation of the specified system of equations,
|
||||
aggregated by the ``reduction`` defined in the ``__init__``.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
residual = torch.hstack(
|
||||
[equation.residual(input_, output_) for equation in self.equations])
|
||||
[equation.residual(input_, output_, params_) for equation in self.equations])
|
||||
|
||||
if self.reduction == 'none':
|
||||
return residual
|
||||
|
||||
@@ -205,6 +205,7 @@ class Plotter:
|
||||
plt.savefig(filename)
|
||||
else:
|
||||
plt.show()
|
||||
plt.close()
|
||||
|
||||
def plot_loss(self,
|
||||
trainer,
|
||||
|
||||
@@ -3,9 +3,11 @@ __all__ = [
|
||||
'SpatialProblem',
|
||||
'TimeDependentProblem',
|
||||
'ParametricProblem',
|
||||
'InverseProblem',
|
||||
]
|
||||
|
||||
from .abstract_problem import AbstractProblem
|
||||
from .spatial_problem import SpatialProblem
|
||||
from .timedep_problem import TimeDependentProblem
|
||||
from .parametric_problem import ParametricProblem
|
||||
from .inverse_problem import InverseProblem
|
||||
|
||||
@@ -109,6 +109,14 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
samples = condition.input_points
|
||||
self.input_pts[condition_name] = samples
|
||||
self._have_sampled_points[condition_name] = True
|
||||
if hasattr(self, 'unknown_parameter_domain'):
|
||||
# initialize the unknown parameters of the inverse problem given
|
||||
# the domain the user gives
|
||||
self.unknown_parameters = {}
|
||||
for i, var in enumerate(self.unknown_variables):
|
||||
range_var = self.unknown_parameter_domain.range_[var]
|
||||
tensor_var = torch.rand(1, requires_grad=True) * range_var[1] + range_var[0]
|
||||
self.unknown_parameters[var] = torch.nn.Parameter(tensor_var)
|
||||
|
||||
def discretise_domain(self,
|
||||
n,
|
||||
@@ -203,6 +211,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
self.input_variables):
|
||||
self._have_sampled_points[location] = True
|
||||
|
||||
|
||||
def add_points(self, new_points):
|
||||
"""
|
||||
Adding points to the already sampled points.
|
||||
@@ -237,7 +246,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
@property
|
||||
def have_sampled_points(self):
|
||||
"""
|
||||
Check if all points for
|
||||
Check if all points for
|
||||
``Location`` are sampled.
|
||||
"""
|
||||
return all(self._have_sampled_points.values())
|
||||
@@ -245,7 +254,7 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
@property
|
||||
def not_sampled_points(self):
|
||||
"""
|
||||
Check which points are
|
||||
Check which points are
|
||||
not sampled.
|
||||
"""
|
||||
# variables which are not sampled
|
||||
@@ -257,3 +266,4 @@ class AbstractProblem(metaclass=ABCMeta):
|
||||
if not is_sample:
|
||||
not_sampled.append(condition_name)
|
||||
return not_sampled
|
||||
|
||||
|
||||
71
pina/problem/inverse_problem.py
Normal file
71
pina/problem/inverse_problem.py
Normal file
@@ -0,0 +1,71 @@
|
||||
"""Module for the ParametricProblem class"""
|
||||
from abc import abstractmethod
|
||||
|
||||
from .abstract_problem import AbstractProblem
|
||||
|
||||
|
||||
class InverseProblem(AbstractProblem):
|
||||
"""
|
||||
The class for the definition of inverse problems, i.e., problems
|
||||
with unknown parameters that have to be learned during the training process
|
||||
from given data.
|
||||
|
||||
Here's an example of a spatial inverse ODE problem, i.e., a spatial
|
||||
ODE problem with an unknown parameter `alpha` as coefficient of the
|
||||
derivative term.
|
||||
|
||||
:Example:
|
||||
>>> from pina.problem import SpatialProblem, InverseProblem
|
||||
>>> from pina.operators import grad
|
||||
>>> from pina.equation import ParametricEquation, FixedValue
|
||||
>>> from pina import Condition
|
||||
>>> from pina.geometry import CartesianDomain
|
||||
>>> import torch
|
||||
>>>
|
||||
>>> class InverseODE(SpatialProblem, InverseProblem):
|
||||
>>>
|
||||
>>> output_variables = ['u']
|
||||
>>> spatial_domain = CartesianDomain({'x': [0, 1]})
|
||||
>>> unknown_parameter_domain = CartesianDomain({'alpha': [1, 10]})
|
||||
>>>
|
||||
>>> def ode_equation(input_, output_, params_):
|
||||
>>> u_x = grad(output_, input_, components=['u'], d=['x'])
|
||||
>>> u = output_.extract(['u'])
|
||||
>>> return params_.extract(['alpha']) * u_x - u
|
||||
>>>
|
||||
>>> def solution_data(input_, output_):
|
||||
>>> x = input_.extract(['x'])
|
||||
>>> solution = torch.exp(x)
|
||||
>>> return output_ - solution
|
||||
>>>
|
||||
>>> conditions = {
|
||||
>>> 'x0': Condition(CartesianDomain({'x': 0}), FixedValue(1.0)),
|
||||
>>> 'D': Condition(CartesianDomain({'x': [0, 1]}), ParametricEquation(ode_equation)),
|
||||
>>> 'data': Condition(CartesianDomain({'x': [0, 1]}), Equation(solution_data))
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def unknown_parameter_domain(self):
|
||||
"""
|
||||
The parameters' domain of the problem.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def unknown_variables(self):
|
||||
"""
|
||||
The parameters of the problem.
|
||||
"""
|
||||
return self.unknown_parameter_domain.variables
|
||||
|
||||
@property
|
||||
def unknown_parameters(self):
|
||||
"""
|
||||
The parameters of the problem.
|
||||
"""
|
||||
return self.__unknown_parameters
|
||||
|
||||
@unknown_parameters.setter
|
||||
def unknown_parameters(self, value):
|
||||
self.__unknown_parameters = value
|
||||
|
||||
@@ -14,7 +14,7 @@ class SpatialProblem(AbstractProblem):
|
||||
:Example:
|
||||
>>> from pina.problem import SpatialProblem
|
||||
>>> from pina.operators import grad
|
||||
>>> from pina.equations import Equation, FixedValue
|
||||
>>> from pina.equation import Equation, FixedValue
|
||||
>>> from pina import Condition
|
||||
>>> from pina.geometry import CartesianDomain
|
||||
>>> import torch
|
||||
@@ -33,7 +33,6 @@ class SpatialProblem(AbstractProblem):
|
||||
>>> conditions = {
|
||||
>>> 'x0': Condition(CartesianDomain({'x': 0, 'alpha':[1, 10]}), FixedValue(1.)),
|
||||
>>> 'D': Condition(CartesianDomain({'x': [0, 1], 'alpha':[1, 10]}), Equation(ode_equation))}
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -14,7 +14,7 @@ class TimeDependentProblem(AbstractProblem):
|
||||
:Example:
|
||||
>>> from pina.problem import SpatialProblem, TimeDependentProblem
|
||||
>>> from pina.operators import grad, laplacian
|
||||
>>> from pina.equations import Equation, FixedValue
|
||||
>>> from pina.equation import Equation, FixedValue
|
||||
>>> from pina import Condition
|
||||
>>> from pina.geometry import CartesianDomain
|
||||
>>> import torch
|
||||
@@ -43,7 +43,6 @@ class TimeDependentProblem(AbstractProblem):
|
||||
>>> 'gamma1': Condition(CartesianDomain({'x':0, 't':[0, 1]}), FixedValue(0.)),
|
||||
>>> 'gamma2': Condition(CartesianDomain({'x':3, 't':[0, 1]}), FixedValue(0.)),
|
||||
>>> 'D': Condition(CartesianDomain({'x': [0, 3], 't':[0, 1]}), Equation(wave_equation))}
|
||||
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -11,6 +11,7 @@ from .solver import SolverInterface
|
||||
from ..label_tensor import LabelTensor
|
||||
from ..utils import check_consistency
|
||||
from ..loss import LossInterface
|
||||
from ..problem import InverseProblem
|
||||
from torch.nn.modules.loss import _Loss
|
||||
|
||||
torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||
@@ -18,14 +19,14 @@ torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
|
||||
|
||||
class PINN(SolverInterface):
|
||||
"""
|
||||
PINN solver class. This class implements Physics Informed Neural
|
||||
PINN solver class. This class implements Physics Informed Neural
|
||||
Network solvers, using a user specified ``model`` to solve a specific
|
||||
``problem``.
|
||||
``problem``. It can be used for solving both forward and inverse problems.
|
||||
|
||||
.. seealso::
|
||||
|
||||
**Original reference**: Karniadakis, G. E., Kevrekidis, I. G., Lu, L.,
|
||||
Perdikaris, P., Wang, S., & Yang, L. (2021).
|
||||
**Original reference**: Karniadakis, G. E., Kevrekidis, I. G., Lu, L.,
|
||||
Perdikaris, P., Wang, S., & Yang, L. (2021).
|
||||
Physics-informed machine learning. Nature Reviews Physics, 3(6), 422-440.
|
||||
<https://doi.org/10.1038/s42254-021-00314-5>`_.
|
||||
"""
|
||||
@@ -45,7 +46,7 @@ class PINN(SolverInterface):
|
||||
},
|
||||
):
|
||||
'''
|
||||
:param AbstractProblem problem: The formualation of the problem.
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use.
|
||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
||||
default :class:`torch.nn.MSELoss`.
|
||||
@@ -74,12 +75,18 @@ class PINN(SolverInterface):
|
||||
self._loss = loss
|
||||
self._neural_net = self.models[0]
|
||||
|
||||
# inverse problem handling
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self._params = self.problem.unknown_parameters
|
||||
else:
|
||||
self._params = None
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Forward pass implementation for the PINN
|
||||
solver.
|
||||
|
||||
:param torch.Tensor x: Input tensor.
|
||||
:param torch.Tensor x: Input tensor.
|
||||
:return: PINN solution.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
@@ -93,17 +100,30 @@ class PINN(SolverInterface):
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
"""
|
||||
# if the problem is an InverseProblem, add the unknown parameters
|
||||
# to the parameters that the optimizer needs to optimize
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self.optimizers[0].add_param_group(
|
||||
{'params': [self._params[var] for var in self.problem.unknown_variables]}
|
||||
)
|
||||
return self.optimizers, [self.scheduler]
|
||||
|
||||
|
||||
def _clamp_inverse_problem_params(self):
|
||||
for v in self._params:
|
||||
self._params[v].data.clamp_(
|
||||
self.problem.unknown_parameter_domain.range_[v][0],
|
||||
self.problem.unknown_parameter_domain.range_[v][1])
|
||||
|
||||
def _loss_data(self, input, output):
|
||||
return self.loss(self.forward(input), output)
|
||||
|
||||
|
||||
def _loss_phys(self, samples, equation):
|
||||
residual = equation.residual(samples, self.forward(samples))
|
||||
try:
|
||||
residual = equation.residual(samples, self.forward(samples))
|
||||
except TypeError: # this occurs when the function has three inputs, i.e. inverse problem
|
||||
residual = equation.residual(samples, self.forward(samples), self._params)
|
||||
return self.loss(torch.zeros_like(residual, requires_grad=True), residual)
|
||||
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
"""
|
||||
PINN solver training step.
|
||||
@@ -137,15 +157,20 @@ class PINN(SolverInterface):
|
||||
else:
|
||||
raise ValueError("Batch size not supported")
|
||||
|
||||
# TODO for users this us hard to remebeber when creating a new solver, to fix in a smarter way
|
||||
# TODO for users this us hard to remember when creating a new solver, to fix in a smarter way
|
||||
loss = loss.as_subclass(torch.Tensor)
|
||||
|
||||
# add condition losses and accumulate logging for each epoch
|
||||
# # add condition losses and accumulate logging for each epoch
|
||||
condition_losses.append(loss * condition.data_weight)
|
||||
self.log(condition_name + '_loss', float(loss),
|
||||
prog_bar=True, logger=True, on_epoch=True, on_step=False)
|
||||
|
||||
# add to tot loss and accumulate logging for each epoch
|
||||
# clamp unknown parameters of the InverseProblem to their domain ranges (if needed)
|
||||
if isinstance(self.problem, InverseProblem):
|
||||
self._clamp_inverse_problem_params()
|
||||
|
||||
# TODO Fix the bug, tot_loss is a label tensor without labels
|
||||
# we need to pass it as a torch tensor to make everything work
|
||||
total_loss = sum(condition_losses)
|
||||
self.log('mean_loss', float(total_loss / len(condition_losses)),
|
||||
prog_bar=True, logger=True, on_epoch=True, on_step=False)
|
||||
|
||||
9
tutorials/README.md
vendored
9
tutorials/README.md
vendored
@@ -6,20 +6,21 @@ In this folder we collect useful tutorials in order to understand the principles
|
||||
|
||||
| Description | Tutorial |
|
||||
|---------------|-----------|
|
||||
Introduction to PINA for Physics Informed Neural Networks training|[[.ipynb](tutorial1/tutorial.ipynb), [.py](tutorial1/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial1/tutorial.html)]|
|
||||
Building custom geometries with PINA `Location` class|[[.ipynb](tutorial1/tutorial.ipynb), [.py](tutorial1/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial1/tutorial.html)]|
|
||||
Introduction to PINA for Physics Informed Neural Networks training|[[.ipynb](tutorial1/tutorial.ipynb), [.py](tutorial1/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial1/tutorial.html)]|
|
||||
Building custom geometries with PINA `Location` class|[[.ipynb](tutorial1/tutorial.ipynb), [.py](tutorial1/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial1/tutorial.html)]|
|
||||
|
||||
|
||||
## Physics Informed Neural Networks
|
||||
| Description | Tutorial |
|
||||
| Description | Tutorial |
|
||||
|---------------|-----------|
|
||||
Two dimensional Poisson problem using Extra Features Learning |[[.ipynb](tutorial2/tutorial.ipynb), [.py](tutorial2/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial2/tutorial.html)]|
|
||||
Two dimensional Wave problem with hard constraint |[[.ipynb](tutorial3/tutorial.ipynb), [.py](tutorial3/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial3/tutorial.html)]|
|
||||
Resolution of a 2D Poisson inverse problem |[[.ipynb](tutorial7/tutorial.ipynb), [.py](tutorial7/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial7/tutorial.html)]|
|
||||
|
||||
## Neural Operator Learning
|
||||
| Description | Tutorial |
|
||||
|---------------|-----------|
|
||||
Two dimensional Darcy flow using the Fourier Neural Operator |[[.ipynb](tutorial5/tutorial.ipynb), [.py](tutorial5/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial5/tutorial.html)]|
|
||||
Two dimensional Darcy flow using the Fourier Neural Operator |[[.ipynb](tutorial5/tutorial.ipynb), [.py](tutorial5/tutorial.py), [.html](http://mathlab.github.io/PINA/_rst/tutorial5/tutorial.html)]|
|
||||
|
||||
## Supervised Learning
|
||||
| Description | Tutorial |
|
||||
|
||||
BIN
tutorials/tutorial7/data/pinn_solution_0.5_0.5
vendored
Normal file
BIN
tutorials/tutorial7/data/pinn_solution_0.5_0.5
vendored
Normal file
Binary file not shown.
BIN
tutorials/tutorial7/data/pts_0.5_0.5
vendored
Normal file
BIN
tutorials/tutorial7/data/pts_0.5_0.5
vendored
Normal file
Binary file not shown.
368
tutorials/tutorial7/tutorial.ipynb
vendored
Normal file
368
tutorials/tutorial7/tutorial.ipynb
vendored
Normal file
File diff suppressed because one or more lines are too long
197
tutorials/tutorial7/tutorial.py
vendored
Normal file
197
tutorials/tutorial7/tutorial.py
vendored
Normal file
@@ -0,0 +1,197 @@
|
||||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
|
||||
# # Tutorial 7: Resolution of an inverse problem
|
||||
|
||||
# ### Introduction to the inverse problem
|
||||
|
||||
# This tutorial shows how to solve an inverse Poisson problem with Physics-Informed Neural Networks. The problem definition is that of a Poisson problem with homogeneous boundary conditions and it reads:
|
||||
# \begin{equation}
|
||||
# \begin{cases}
|
||||
# \Delta u = e^{-2(x-\mu_1)^2-2(y-\mu_2)^2} \text{ in } \Omega\, ,\\
|
||||
# u = 0 \text{ on }\partial \Omega,\\
|
||||
# u(\mu_1, \mu_2) = \text{ data}
|
||||
# \end{cases}
|
||||
# \end{equation}
|
||||
# where $\Omega$ is a square domain $[-2, 2] \times [-2, 2]$, and $\partial \Omega=\Gamma_1 \cup \Gamma_2 \cup \Gamma_3 \cup \Gamma_4$ is the union of the boundaries of the domain.
|
||||
#
|
||||
# This kind of problem, namely the "inverse problem", has two main goals:
|
||||
# - find the solution $u$ that satisfies the Poisson equation;
|
||||
# - find the unknown parameters ($\mu_1$, $\mu_2$) that better fit some given data (third equation in the system above).
|
||||
#
|
||||
# In order to achieve both the goals we will need to define an `InverseProblem` in PINA.
|
||||
|
||||
# Let's start with useful imports.
|
||||
|
||||
# In[1]:
|
||||
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import torch
|
||||
from pytorch_lightning.callbacks import Callback
|
||||
from pina.problem import SpatialProblem, InverseProblem
|
||||
from pina.operators import laplacian
|
||||
from pina.model import FeedForward
|
||||
from pina.equation import Equation, FixedValue
|
||||
from pina import Condition, Trainer
|
||||
from pina.solvers import PINN
|
||||
from pina.geometry import CartesianDomain
|
||||
|
||||
|
||||
# Then, we import the pre-saved data, for ($\mu_1$, $\mu_2$)=($0.5$, $0.5$). These two values are the optimal parameters that we want to find through the neural network training. In particular, we import the `input_points`(the spatial coordinates), and the `output_points` (the corresponding $u$ values evaluated at the `input_points`).
|
||||
|
||||
# In[2]:
|
||||
|
||||
|
||||
data_output = torch.load('data/pinn_solution_0.5_0.5').detach()
|
||||
data_input = torch.load('data/pts_0.5_0.5')
|
||||
|
||||
|
||||
# Moreover, let's plot also the data points and the reference solution: this is the expected output of the neural network.
|
||||
|
||||
# In[3]:
|
||||
|
||||
|
||||
points = data_input.extract(['x', 'y']).detach().numpy()
|
||||
truth = data_output.detach().numpy()
|
||||
|
||||
plt.scatter(points[:, 0], points[:, 1], c=truth, s=8)
|
||||
plt.axis('equal')
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
|
||||
|
||||
# ### Inverse problem definition in PINA
|
||||
|
||||
# Then, we initialize the Poisson problem, that is inherited from the `SpatialProblem` and from the `InverseProblem` classes. We here have to define all the variables, and the domain where our unknown parameters ($\mu_1$, $\mu_2$) belong. Notice that the laplace equation takes as inputs also the unknown variables, that will be treated as parameters that the neural network optimizes during the training process.
|
||||
|
||||
# In[4]:
|
||||
|
||||
|
||||
### Define ranges of variables
|
||||
x_min = -2
|
||||
x_max = 2
|
||||
y_min = -2
|
||||
y_max = 2
|
||||
|
||||
class Poisson(SpatialProblem, InverseProblem):
|
||||
'''
|
||||
Problem definition for the Poisson equation.
|
||||
'''
|
||||
output_variables = ['u']
|
||||
spatial_domain = CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]})
|
||||
# define the ranges for the parameters
|
||||
unknown_parameter_domain = CartesianDomain({'mu1': [-1, 1], 'mu2': [-1, 1]})
|
||||
|
||||
def laplace_equation(input_, output_, params_):
|
||||
'''
|
||||
Laplace equation with a force term.
|
||||
'''
|
||||
force_term = torch.exp(
|
||||
- 2*(input_.extract(['x']) - params_['mu1'])**2
|
||||
- 2*(input_.extract(['y']) - params_['mu2'])**2)
|
||||
delta_u = laplacian(output_, input_, components=['u'], d=['x', 'y'])
|
||||
|
||||
return delta_u - force_term
|
||||
|
||||
# define the conditions for the loss (boundary conditions, equation, data)
|
||||
conditions = {
|
||||
'gamma1': Condition(location=CartesianDomain({'x': [x_min, x_max],
|
||||
'y': y_max}),
|
||||
equation=FixedValue(0.0, components=['u'])),
|
||||
'gamma2': Condition(location=CartesianDomain({'x': [x_min, x_max], 'y': y_min
|
||||
}),
|
||||
equation=FixedValue(0.0, components=['u'])),
|
||||
'gamma3': Condition(location=CartesianDomain({'x': x_max, 'y': [y_min, y_max]
|
||||
}),
|
||||
equation=FixedValue(0.0, components=['u'])),
|
||||
'gamma4': Condition(location=CartesianDomain({'x': x_min, 'y': [y_min, y_max]
|
||||
}),
|
||||
equation=FixedValue(0.0, components=['u'])),
|
||||
'D': Condition(location=CartesianDomain({'x': [x_min, x_max], 'y': [y_min, y_max]
|
||||
}),
|
||||
equation=Equation(laplace_equation)),
|
||||
'data': Condition(input_points=data_input.extract(['x', 'y']), output_points=data_output)
|
||||
}
|
||||
|
||||
problem = Poisson()
|
||||
|
||||
|
||||
# Then, we define the model of the neural network we want to use. Here we used a model which impose hard constrains on the boundary conditions, as also done in the Wave tutorial!
|
||||
|
||||
# In[5]:
|
||||
|
||||
|
||||
model = FeedForward(
|
||||
layers=[20, 20, 20],
|
||||
func=torch.nn.Softplus,
|
||||
output_dimensions=len(problem.output_variables),
|
||||
input_dimensions=len(problem.input_variables)
|
||||
)
|
||||
|
||||
|
||||
# After that, we discretize the spatial domain.
|
||||
|
||||
# In[6]:
|
||||
|
||||
|
||||
problem.discretise_domain(20, 'grid', locations=['D'], variables=['x', 'y'])
|
||||
problem.discretise_domain(1000, 'random', locations=['gamma1', 'gamma2',
|
||||
'gamma3', 'gamma4'], variables=['x', 'y'])
|
||||
|
||||
|
||||
# Here, we define a simple callback for the trainer. We use this callback to save the parameters predicted by the neural network during the training. The parameters are saved every 100 epochs as `torch` tensors in a specified directory (`tmp_dir` in our case).
|
||||
# The goal is to read the saved parameters after training and plot their trend across the epochs.
|
||||
|
||||
# In[7]:
|
||||
|
||||
|
||||
# temporary directory for saving logs of training
|
||||
tmp_dir = "tmp_poisson_inverse"
|
||||
|
||||
class SaveParameters(Callback):
|
||||
'''
|
||||
Callback to save the parameters of the model every 100 epochs.
|
||||
'''
|
||||
def on_train_epoch_end(self, trainer, __):
|
||||
if trainer.current_epoch % 100 == 99:
|
||||
torch.save(trainer.solver.problem.unknown_parameters, '{}/parameters_epoch{}'.format(tmp_dir, trainer.current_epoch))
|
||||
|
||||
|
||||
# Then, we define the `PINN` object and train the solver using the `Trainer`.
|
||||
|
||||
# In[8]:
|
||||
|
||||
|
||||
### train the problem with PINN
|
||||
max_epochs = 5000
|
||||
pinn = PINN(problem, model, optimizer_kwargs={'lr':0.005})
|
||||
# define the trainer for the solver
|
||||
trainer = Trainer(solver=pinn, accelerator='cpu', max_epochs=max_epochs,
|
||||
default_root_dir=tmp_dir, callbacks=[SaveParameters()])
|
||||
trainer.train()
|
||||
|
||||
|
||||
# One can now see how the parameters vary during the training by reading the saved solution and plotting them. The plot shows that the parameters stabilize to their true value before reaching the epoch $1000$!
|
||||
|
||||
# In[9]:
|
||||
|
||||
|
||||
epochs_saved = range(99, max_epochs, 100)
|
||||
parameters = torch.empty((int(max_epochs/100), 2))
|
||||
for i, epoch in enumerate(epochs_saved):
|
||||
params_torch = torch.load('{}/parameters_epoch{}'.format(tmp_dir, epoch))
|
||||
for e, var in enumerate(pinn.problem.unknown_variables):
|
||||
parameters[i, e] = params_torch[var].data
|
||||
|
||||
# Plot parameters
|
||||
plt.close()
|
||||
plt.plot(epochs_saved, parameters[:, 0], label='mu1', marker='o')
|
||||
plt.plot(epochs_saved, parameters[:, 1], label='mu2', marker='s')
|
||||
plt.ylim(-1, 1)
|
||||
plt.grid()
|
||||
plt.legend()
|
||||
plt.xlabel('Epoch')
|
||||
plt.ylabel('Parameter value')
|
||||
plt.show()
|
||||
|
||||
Reference in New Issue
Block a user