fix pinn doc

authored by giovanni on 2025-03-13 16:50:05 +01:00
committed by Nicola Demo
parent 9a26c94e07
commit 28ef4c823b
8 changed files with 377 additions and 337 deletions


@@ -1,4 +1,4 @@
"""Module for Residual-Based Attention PINN."""
"""Module for the Residual-Based Attention PINN solver."""
from copy import deepcopy
import torch
@@ -9,14 +9,14 @@ from ...utils import check_consistency
 class RBAPINN(PINN):
     r"""
-    Residual-based Attention PINN (RBAPINN) solver class.
-    This class implements Residual-based Attention Physics Informed Neural
-    Network solver, using a user specified ``model`` to solve a specific
-    ``problem``. It can be used for solving both forward and inverse problems.
+    Residual-based Attention Physics-Informed Neural Network (RBAPINN) solver
+    class. This class implements the Residual-based Attention Physics-Informed
+    Neural Network solver, using a user specified ``model`` to solve a specific
+    ``problem``. It can be used to solve both forward and inverse problems.
-    The Residual-based Attention Physics Informed Neural Network aims to find
-    the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
-    of the differential problem:
+    The Residual-based Attention Physics-Informed Neural Network solver aims to
+    find the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a
+    differential problem:
     .. math::
@@ -26,7 +26,7 @@ class RBAPINN(PINN):
         \mathbf{x}\in\partial\Omega
         \end{cases}
-    minimizing the loss function
+    minimizing the loss function:
     .. math::
@@ -38,23 +38,23 @@ class RBAPINN(PINN):
         \left( \mathcal{B}[\mathbf{u}](\mathbf{x})
         \right),
-    denoting the weights as
+    denoting the weights as:
     :math:`\lambda_{\Omega}^1, \dots, \lambda_{\Omega}^{N_\Omega}` and
     :math:`\lambda_{\partial \Omega}^1, \dots,
     \lambda_{\partial \Omega}^{N_{\partial \Omega}}`
     for :math:`\Omega` and :math:`\partial \Omega`, respectively.
-    Residual-based Attention Physics Informed Neural Network computes
-    the weights by updating them at every epoch as follows
+    Residual-based Attention Physics-Informed Neural Network updates the weights
+    of the residuals at every epoch as follows:
     .. math::
         \lambda_i^{k+1} \leftarrow \gamma\lambda_i^{k} +
         \eta\frac{\lvert r_i\rvert}{\max_j \lvert r_j\rvert},
-    where :math:`r_i` denotes the residual at point :math:`i`,
-    :math:`\gamma` denotes the decay rate, and :math:`\eta` is
-    the learning rate for the weights' update.
+    where :math:`r_i` denotes the residual at point :math:`i`, :math:`\gamma`
+    denotes the decay rate, and :math:`\eta` is the learning rate for the
+    weights' update.
     .. seealso::
         **Original reference**: Sokratis J. Anagnostopoulos, Juan D. Toscano,
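The update rule documented above is easy to express directly in PyTorch. The following is a minimal, self-contained sketch of the residual-based attention update, not the PINA implementation; the function name is illustrative:

```python
import torch

def rba_update(weights, residuals, eta=0.001, gamma=0.999):
    # RBA rule from the docstring above:
    # lambda_i^{k+1} <- gamma * lambda_i^k + eta * |r_i| / max_j |r_j|
    r_abs = residuals.detach().abs()
    return gamma * weights + eta * r_abs / r_abs.max()

# Example: after repeated updates the weights concentrate on the points
# with persistently large residuals (here index 2).
weights = torch.zeros(5)
residuals = torch.tensor([0.1, 0.5, 2.0, 0.05, 1.0])
for _ in range(1000):
    weights = rba_update(weights, residuals)
print(weights)  # largest weight at index 2
```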
@@ -78,20 +78,25 @@ class RBAPINN(PINN):
         gamma=0.999,
     ):
         """
-        :param torch.nn.Module model: The neural network model to use.
-        :param AbstractProblem problem: The formulation of the problem.
-        :param torch.optim.Optimizer optimizer: The neural network optimizer to
-            use; default `None`.
-        :param torch.optim.LRScheduler scheduler: Learning rate scheduler;
-            default `None`.
-        :param WeightingInterface weighting: The weighting schema to use;
-            default `None`.
-        :param torch.nn.Module loss: The loss function to be minimized;
-            default `None`.
+        Initialization of the :class:`RBAPINN` class.
+        :param AbstractProblem problem: The problem to be solved.
+        :param torch.nn.Module model: The neural network model to be used.
+        :param torch.optim.Optimizer optimizer: The optimizer to be used.
+            If `None`, the Adam optimizer is used. Default is ``None``.
+        :param torch.optim.LRScheduler scheduler: Learning rate scheduler.
+            If `None`, the constant learning rate scheduler is used.
+            Default is ``None``.
+        :param WeightingInterface weighting: The weighting schema to be used.
+            If `None`, no weighting schema is used. Default is ``None``.
+        :param torch.nn.Module loss: The loss function to be minimized.
+            If `None`, the Mean Squared Error (MSE) loss is used.
+            Default is ``None``.
         :param float | int eta: The learning rate for the weights of the
-            residual; default 0.001.
+            residuals. Default is ``0.001``.
         :param float gamma: The decay parameter in the update of the weights
-            of the residual. Must be between 0 and 1; default 0.999.
+            of the residuals. Must be between ``0`` and ``1``.
+            Default is ``0.999``.
         """
         super().__init__(
             model=model,
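Given the updated signature, a typical instantiation looks like the sketch below; `problem` and `model` are placeholders for a PINA problem definition and a `torch.nn.Module`, and the import path is an assumption that may vary between PINA releases:

```python
from pina.solver import RBAPINN  # hypothetical import path; may vary by release

solver = RBAPINN(
    problem=problem,  # placeholder: an AbstractProblem instance
    model=model,      # placeholder: a torch.nn.Module approximating u
    eta=0.001,        # learning rate for the residual weights (default)
    gamma=0.999,      # decay rate of the attention weights (default)
)
# optimizer, scheduler, weighting and loss are left as None, so Adam, a
# constant learning rate, no weighting schema and the MSE loss are used.
```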
@@ -122,6 +127,11 @@ class RBAPINN(PINN):
     # for now RBAPINN is implemented only for batch_size = None
     def on_train_start(self):
+        """
+        Hook method called at the beginning of training.
+        :raises NotImplementedError: If the batch size is not ``None``.
+        """
         if self.trainer.batch_size is not None:
             raise NotImplementedError(
                 "RBAPINN only works with full batch "
@@ -132,11 +142,11 @@ class RBAPINN(PINN):
     def _vect_to_scalar(self, loss_value):
         """
-        Elaboration of the pointwise loss.
+        Computation of the scalar loss.
-        :param LabelTensor loss_value: the matrix of pointwise loss.
-        :return: the scalar loss.
+        :param LabelTensor loss_value: the tensor of pointwise losses.
+        :raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
+        :return: The computed scalar loss.
         :rtype: LabelTensor
         """
         if self.loss.reduction == "mean":
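The documented contract of `_vect_to_scalar` can be mirrored by a small standalone function; this is an illustrative re-implementation of the `mean`/`sum` reduction behaviour described in the docstring, not the method body itself:

```python
import torch

def vect_to_scalar(loss_value, reduction="mean"):
    # Reduce a tensor of pointwise losses to a scalar; only "mean" and
    # "sum" are accepted, matching the :raises RuntimeError: clause above.
    if reduction == "mean":
        return torch.mean(loss_value)
    if reduction == "sum":
        return torch.sum(loss_value)
    raise RuntimeError(
        f"Invalid reduction {reduction!r}, expected 'mean' or 'sum'."
    )
```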
@@ -152,14 +162,12 @@ class RBAPINN(PINN):
     def loss_phys(self, samples, equation):
         """
-        Computes the physics loss for the residual-based attention PINN
-        solver based on given samples and equation.
+        Computes the physics loss for the physics-informed solver based on the
+        provided samples and equation.
         :param LabelTensor samples: The samples to evaluate the physics loss.
-        :param EquationInterface equation: The governing equation
-            representing the physics.
-        :return: The physics loss calculated based on given
-            samples and equation.
+        :param EquationInterface equation: The governing equation.
+        :return: The computed physics loss.
         :rtype: LabelTensor
         """
         residual = self.compute_residual(samples=samples, equation=equation)
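Putting the pieces together, the physics loss combines the residual with the attention weights roughly as sketched below. The diff truncates the method body, so the weighting and reduction details here are assumptions (an MSE-style reduction against a zero target), not the actual PINA implementation:

```python
import torch

def loss_phys_sketch(residual, weights, eta=0.001, gamma=0.999):
    # 1) Update the attention weights from the current residuals (RBA rule).
    r_abs = residual.detach().abs()
    weights = gamma * weights + eta * r_abs / r_abs.max()
    # 2) Weight the pointwise squared residuals and reduce to a scalar;
    #    the squared-residual form of the loss is an assumption here.
    loss = torch.mean((weights * residual) ** 2)
    return loss, weights
```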