fix pinn doc

This commit is contained in:
giovanni
2025-03-13 16:50:05 +01:00
committed by Nicola Demo
parent 9a26c94e07
commit 28ef4c823b
8 changed files with 377 additions and 337 deletions

View File

@@ -1,4 +1,4 @@
"""TODO""" """Module for the physics-informed solvers."""
__all__ = [ __all__ = [
"PINNInterface", "PINNInterface",

View File

@@ -1,4 +1,4 @@
"""Module for Causal PINN.""" """Module for the Causal PINN solver."""
import torch import torch
@@ -9,14 +9,13 @@ from ...utils import check_consistency
class CausalPINN(PINN): class CausalPINN(PINN):
r""" r"""
Causal Physics Informed Neural Network (CausalPINN) solver class. Causal Physics-Informed Neural Network (CausalPINN) solver class.
This class implements Causal Physics Informed Neural This class implements the Causal Physics-Informed Neural Network solver,
Network solver, using a user specified ``model`` to solve a specific using a user specified ``model`` to solve a specific ``problem``.
``problem``. It can be used for solving both forward and inverse problems. It can be used to solve both forward and inverse problems.
The Causal Physics Informed Network aims to find The Causal Physics-Informed Neural Network solver aims to find the solution
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential problem:
of the differential problem:
.. math:: .. math::
@@ -26,7 +25,7 @@ class CausalPINN(PINN):
\mathbf{x}\in\partial\Omega \mathbf{x}\in\partial\Omega
\end{cases} \end{cases}
minimizing the loss function minimizing the loss function:
.. math:: .. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N_t}\sum_{i=1}^{N_t} \mathcal{L}_{\rm{problem}} = \frac{1}{N_t}\sum_{i=1}^{N_t}
@@ -45,14 +44,12 @@ class CausalPINN(PINN):
.. math:: .. math::
\omega_i = \exp\left(\epsilon \sum_{k=1}^{i-1}\mathcal{L}_r(t_k)\right). \omega_i = \exp\left(\epsilon \sum_{k=1}^{i-1}\mathcal{L}_r(t_k)\right).
:math:`\epsilon` is an hyperparameter, default set to :math:`100`, while :math:`\epsilon` is an hyperparameter, set by default to :math:`100`, while
:math:`\mathcal{L}` is a specific loss function, :math:`\mathcal{L}` is a specific loss function, typically the MSE:
default Mean Square Error:
.. math:: .. math::
\mathcal{L}(v) = \| v \|^2_2. \mathcal{L}(v) = \| v \|^2_2.
.. seealso:: .. seealso::
**Original reference**: Wang, Sifan, Shyam Sankaran, and Paris **Original reference**: Wang, Sifan, Shyam Sankaran, and Paris
@@ -62,9 +59,8 @@ class CausalPINN(PINN):
DOI `10.1016 <https://doi.org/10.1016/j.cma.2024.116813>`_. DOI `10.1016 <https://doi.org/10.1016/j.cma.2024.116813>`_.
.. note:: .. note::
This class can only work for problems inheriting This class is only compatible with problems that inherit from the
from at least :class:`~pina.problem.TimeDependentProblem` class.
:class:`~pina.problem.timedep_problem.TimeDependentProblem` class.
""" """
def __init__( def __init__(
@@ -78,17 +74,23 @@ class CausalPINN(PINN):
eps=100, eps=100,
): ):
""" """
:param torch.nn.Module model: The neural network model to use. Initialization of the :class:`CausalPINN` class.
:param AbstractProblem problem: The formulation of the problem.
:param torch.optim.Optimizer optimizer: The neural network optimizer to :param AbstractProblem problem: The problem to be solved. It must
use; default `None`. inherit from at least :class:`~pina.problem.TimeDependentProblem`.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler; :param torch.nn.Module model: The neural network model to be used.
default `None`. :param torch.optim.Optimizer optimizer: The optimizer to be used
:param WeightingInterface weighting: The weighting schema to use; If `None`, the Adam optimizer is used. Default is ``None``.
default `None`. :param torch.optim.LRScheduler scheduler: Learning rate scheduler.
:param torch.nn.Module loss: The loss function to be minimized; If `None`, the constant learning rate scheduler is used.
default `None`. Default is ``None``.
:param float eps: The exponential decay parameter; default `100`. :param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the Mean Squared Error (MSE) loss is used.
Default is `None`.
:param float eps: The exponential decay parameter. Default is ``100``.
:raises ValueError: If the problem is not a TimeDependentProblem.
""" """
super().__init__( super().__init__(
model=model, model=model,
@@ -110,14 +112,12 @@ class CausalPINN(PINN):
def loss_phys(self, samples, equation): def loss_phys(self, samples, equation):
""" """
Computes the physics loss for the Causal PINN solver based on given Computes the physics loss for the physics-informed solver based on the
samples and equation. provided samples and equation.
:param LabelTensor samples: The samples to evaluate the physics loss. :param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation :param EquationInterface equation: The governing equation.
representing the physics. :return: The computed physics loss.
:return: The physics loss calculated based on given
samples and equation.
:rtype: LabelTensor :rtype: LabelTensor
""" """
# split sequentially ordered time tensors into chunks # split sequentially ordered time tensors into chunks
@@ -146,13 +146,16 @@ class CausalPINN(PINN):
def eps(self): def eps(self):
""" """
The exponential decay parameter. The exponential decay parameter.
:return: The exponential decay parameter.
:rtype: float
""" """
return self._eps return self._eps
@eps.setter @eps.setter
def eps(self, value): def eps(self, value):
""" """
Setter method for the eps parameter. Set the exponential decay parameter.
:param float value: The exponential decay parameter. :param float value: The exponential decay parameter.
""" """
@@ -161,10 +164,10 @@ class CausalPINN(PINN):
def _sort_label_tensor(self, tensor): def _sort_label_tensor(self, tensor):
""" """
Sorts the label tensor based on time variables. Sort the tensor with respect to the temporal variables.
:param LabelTensor tensor: The label tensor to be sorted. :param LabelTensor tensor: The tensor to be sorted.
:return: The sorted label tensor based on time variables. :return: The tensor sorted with respect to the temporal variables.
:rtype: LabelTensor :rtype: LabelTensor
""" """
# labels input tensors # labels input tensors
@@ -179,11 +182,12 @@ class CausalPINN(PINN):
def _split_tensor_into_chunks(self, tensor): def _split_tensor_into_chunks(self, tensor):
""" """
Splits the label tensor into chunks based on time. Split the tensor into chunks based on time.
:param LabelTensor tensor: The label tensor to be split. :param LabelTensor tensor: The tensor to be split.
:return: Tuple containing the chunks and the original labels. :return: A tuple containing the list of tensor chunks and the
:rtype: Tuple[List[LabelTensor], List] corresponding labels.
:rtype: tuple[list[LabelTensor], list[str]]
""" """
# extract labels # extract labels
labels = tensor.labels labels = tensor.labels
@@ -199,7 +203,7 @@ class CausalPINN(PINN):
def _compute_weights(self, loss): def _compute_weights(self, loss):
""" """
Computes the weights for the physics loss based on the cumulative loss. Compute the weights for the physics loss based on the cumulative loss.
:param LabelTensor loss: The physics loss values. :param LabelTensor loss: The physics loss values.
:return: The computed weights for the physics loss. :return: The computed weights for the physics loss.

View File

@@ -1,4 +1,4 @@
"""Module for Competitive PINN.""" """Module for the Competitive PINN solver."""
import copy import copy
import torch import torch
@@ -10,14 +10,14 @@ from ..solver import MultiSolverInterface
class CompetitivePINN(PINNInterface, MultiSolverInterface): class CompetitivePINN(PINNInterface, MultiSolverInterface):
r""" r"""
Competitive Physics Informed Neural Network (PINN) solver class. Competitive Physics-Informed Neural Network (CompetitivePINN) solver class.
This class implements Competitive Physics Informed Neural This class implements the Competitive Physics-Informed Neural Network
Network solver, using a user specified ``model`` to solve a specific solver, using a user specified ``model`` to solve a specific ``problem``.
``problem``. It can be used for solving both forward and inverse problems. It can be used to solve both forward and inverse problems.
The Competitive Physics Informed Network aims to find The Competitive Physics-Informed Neural Network solver aims to find the
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential
of the differential problem: problem:
.. math:: .. math::
@@ -27,18 +27,18 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
\mathbf{x}\in\partial\Omega \mathbf{x}\in\partial\Omega
\end{cases} \end{cases}
with a minimization (on ``model`` parameters) maximation ( minimizing the loss function with respect to the model parameters, while
on ``discriminator`` parameters) of the loss function maximizing it with respect to the discriminator parameters:
.. math:: .. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(D(\mathbf{x}_i)\mathcal{A}[\mathbf{u}](\mathbf{x}_i))+ \mathcal{L}(D(\mathbf{x}_i)\mathcal{A}[\mathbf{u}](\mathbf{x}_i))+
\frac{1}{N}\sum_{i=1}^N \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(D(\mathbf{x}_i)\mathcal{B}[\mathbf{u}](\mathbf{x}_i)) \mathcal{L}(D(\mathbf{x}_i)\mathcal{B}[\mathbf{u}](\mathbf{x}_i)),
where :math:`D` is the discriminator network, which tries to find the points where :math:D is the discriminator network, which identifies the points
where the network performs worst, and :math:`\mathcal{L}` is a specific loss where the model performs worst, and :math:\mathcal{L} is a specific loss
function, default Mean Square Error: function, typically the MSE:
.. math:: .. math::
\mathcal{L}(v) = \| v \|^2_2. \mathcal{L}(v) = \| v \|^2_2.
@@ -49,10 +49,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
"Competitive physics informed networks." International Conference on "Competitive physics informed networks." International Conference on
Learning Representations, ICLR 2022 Learning Representations, ICLR 2022
`OpenReview Preprint <https://openreview.net/forum?id=z9SIj-IM7tn>`_. `OpenReview Preprint <https://openreview.net/forum?id=z9SIj-IM7tn>`_.
.. warning::
This solver does not currently support the possibility to pass
``extra_feature``.
""" """
def __init__( def __init__(
@@ -68,24 +64,30 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
loss=None, loss=None,
): ):
""" """
:param AbstractProblem problem: The formulation of the problem. Initialization of the :class:`CompetitivePINN` class.
:param torch.nn.Module model: The neural network model to use
for the model. :param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module discriminator: The neural network model to use :param torch.nn.Module model: The neural network model to be used.
for the discriminator. If ``None``, the discriminator network will :param torch.nn.Module discriminator: The discriminator to be used.
have the same architecture as the model network. If `None`, the discriminator is a deepcopy of the ``model``.
:param torch.optim.Optimizer optimizer_model: The neural network Default is ``None``.
optimizer to use for the model network; default `None`. :param torch.optim.Optimizer optimizer_model: The optimizer of the
:param torch.optim.Optimizer optimizer_discriminator: The neural network ``model``. If `None`, the Adam optimizer is used.
optimizer to use for the discriminator network; default `None`. Default is ``None``.
:param torch.optim.Optimizer optimizer_discriminator: The optimizer of
the ``discriminator``. If `None`, the Adam optimizer is used.
Default is ``None``.
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler :param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
for the model; default `None`. for the ``model``. If `None`, the constant learning rate scheduler
is used. Default is ``None``.
:param torch.optim.LRScheduler scheduler_discriminator: Learning rate :param torch.optim.LRScheduler scheduler_discriminator: Learning rate
scheduler for the discriminator; default `None`. scheduler for the ``discriminator``. If `None`, the constant
:param WeightingInterface weighting: The weighting schema to use; learning rate scheduler is used. Default is ``None``.
default `None`. :param WeightingInterface weighting: The weighting schema to be used.
:param torch.nn.Module loss: The loss function to be minimized; If `None`, no weighting schema is used. Default is ``None``.
default `None`. :param torch.nn.Module loss: The loss function to be minimized.
If `None`, the Mean Squared Error (MSE) loss is used.
Default is `None`.
""" """
if discriminator is None: if discriminator is None:
discriminator = copy.deepcopy(model) discriminator = copy.deepcopy(model)
@@ -103,15 +105,11 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
self.automatic_optimization = False self.automatic_optimization = False
def forward(self, x): def forward(self, x):
r""" """
Forward pass implementation for the PINN solver. It returns the function Forward pass.
evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
:math:`\mathbf{x}`.
:param LabelTensor x: Input tensor for the PINN solver. It expects :param LabelTensor x: Input tensor.
a tensor :math:`N \times D`, where :math:`N` the number of points :return: The output of the neural network.
in the mesh, :math:`D` the dimension of the problem,
:return: PINN solution evaluated at contro points.
:rtype: LabelTensor :rtype: LabelTensor
""" """
return self.neural_net(x) return self.neural_net(x)
@@ -120,9 +118,8 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
""" """
Solver training step, overridden to perform manual optimization. Solver training step, overridden to perform manual optimization.
:param batch: The batch element in the dataloader. :param dict batch: The batch element in the dataloader.
:type batch: tuple :return: The aggregated loss.
:return: The sum of the loss functions.
:rtype: LabelTensor :rtype: LabelTensor
""" """
# train model # train model
@@ -139,14 +136,12 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
def loss_phys(self, samples, equation): def loss_phys(self, samples, equation):
""" """
Computes the physics loss for the Competitive PINN solver based on given Computes the physics loss for the physics-informed solver based on the
samples and equation. provided samples and equation.
:param LabelTensor samples: The samples to evaluate the physics loss. :param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation :param EquationInterface equation: The governing equation.
representing the physics. :return: The computed physics loss.
:return: The physics loss calculated based on given
samples and equation.
:rtype: LabelTensor :rtype: LabelTensor
""" """
# Compute discriminator bets # Compute discriminator bets
@@ -165,7 +160,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
def configure_optimizers(self): def configure_optimizers(self):
""" """
Optimizer configuration for the Competitive PINN solver. Optimizer configuration.
:return: The optimizers and the schedulers :return: The optimizers and the schedulers
:rtype: tuple(list, list) :rtype: tuple(list, list)
@@ -198,16 +193,13 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
def on_train_batch_end(self, outputs, batch, batch_idx): def on_train_batch_end(self, outputs, batch, batch_idx):
""" """
This method is called at the end of each training batch, and ovverides This method is called at the end of each training batch and overrides
the PytorchLightining implementation for logging the checkpoints. the PyTorch Lightning implementation to log checkpoints.
:param torch.Tensor outputs: The output from the model for the :param torch.Tensor outputs: The ``model``'s output for the current
current batch. batch.
:param tuple batch: The current batch of data. :param dict batch: The current batch of data.
:param int batch_idx: The index of the current batch. :param int batch_idx: The index of the current batch.
:return: Whatever is returned by the parent
method ``on_train_batch_end``.
:rtype: Any
""" """
# increase by one the counter of optimization to save loggers # increase by one the counter of optimization to save loggers
( (
@@ -219,9 +211,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property @property
def neural_net(self): def neural_net(self):
""" """
Returns the neural network model. The model.
:return: The neural network model. :return: The model.
:rtype: torch.nn.Module :rtype: torch.nn.Module
""" """
return self.models[0] return self.models[0]
@@ -229,9 +221,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property @property
def discriminator(self): def discriminator(self):
""" """
Returns the discriminator model (if applicable). The discriminator.
:return: The discriminator model. :return: The discriminator.
:rtype: torch.nn.Module :rtype: torch.nn.Module
""" """
return self.models[1] return self.models[1]
@@ -239,9 +231,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property @property
def optimizer_model(self): def optimizer_model(self):
""" """
Returns the optimizer associated with the neural network model. The optimizer associated to the model.
:return: The optimizer for the neural network model. :return: The optimizer for the model.
:rtype: torch.optim.Optimizer :rtype: torch.optim.Optimizer
""" """
return self.optimizers[0] return self.optimizers[0]
@@ -249,7 +241,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property @property
def optimizer_discriminator(self): def optimizer_discriminator(self):
""" """
Returns the optimizer associated with the discriminator (if applicable). The optimizer associated to the discriminator.
:return: The optimizer for the discriminator. :return: The optimizer for the discriminator.
:rtype: torch.optim.Optimizer :rtype: torch.optim.Optimizer
@@ -259,9 +251,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property @property
def scheduler_model(self): def scheduler_model(self):
""" """
Returns the scheduler associated with the neural network model. The scheduler associated to the model.
:return: The scheduler for the neural network model. :return: The scheduler for the model.
:rtype: torch.optim.lr_scheduler._LRScheduler :rtype: torch.optim.lr_scheduler._LRScheduler
""" """
return self.schedulers[0] return self.schedulers[0]
@@ -269,7 +261,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property @property
def scheduler_discriminator(self): def scheduler_discriminator(self):
""" """
Returns the scheduler associated with the discriminator (if applicable). The scheduler associated to the discriminator.
:return: The scheduler for the discriminator. :return: The scheduler for the discriminator.
:rtype: torch.optim.lr_scheduler._LRScheduler :rtype: torch.optim.lr_scheduler._LRScheduler

View File

@@ -1,4 +1,4 @@
"""Module for Gradient PINN.""" """Module for the Gradient PINN solver."""
import torch import torch
@@ -9,14 +9,14 @@ from ...problem import SpatialProblem
class GradientPINN(PINN): class GradientPINN(PINN):
r""" r"""
Gradient Physics Informed Neural Network (GradientPINN) solver class. Gradient Physics-Informed Neural Network (GradientPINN) solver class.
This class implements Gradient Physics Informed Neural This class implements the Gradient Physics-Informed Neural Network solver,
Network solver, using a user specified ``model`` to solve a specific using a user specified ``model`` to solve a specific ``problem``.
``problem``. It can be used for solving both forward and inverse problems. It can be used to solve both forward and inverse problems.
The Gradient Physics Informed Network aims to find The Gradient Physics-Informed Neural Network solver aims to find the
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential
of the differential problem: problem:
.. math:: .. math::
@@ -26,7 +26,7 @@ class GradientPINN(PINN):
\mathbf{x}\in\partial\Omega \mathbf{x}\in\partial\Omega
\end{cases} \end{cases}
minimizing the loss function minimizing the loss function;
.. math:: .. math::
\mathcal{L}_{\rm{problem}} =& \frac{1}{N}\sum_{i=1}^N \mathcal{L}_{\rm{problem}} =& \frac{1}{N}\sum_{i=1}^N
@@ -39,8 +39,7 @@ class GradientPINN(PINN):
\nabla_{\mathbf{x}}\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)) \nabla_{\mathbf{x}}\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i))
where :math:`\mathcal{L}` is a specific loss function, where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
default Mean Square Error:
.. math:: .. math::
\mathcal{L}(v) = \| v \|^2_2. \mathcal{L}(v) = \| v \|^2_2.
@@ -54,9 +53,8 @@ class GradientPINN(PINN):
DOI: `10.1016 <https://doi.org/10.1016/j.cma.2022.114823>`_. DOI: `10.1016 <https://doi.org/10.1016/j.cma.2022.114823>`_.
.. note:: .. note::
This class can only work for problems inheriting This class is only compatible with problems that inherit from the
from at least :class:`~pina.problem.spatial_problem.SpatialProblem` :class:`~pina.problem.SpatialProblem` class.
class.
""" """
def __init__( def __init__(
@@ -69,19 +67,23 @@ class GradientPINN(PINN):
loss=None, loss=None,
): ):
""" """
:param torch.nn.Module model: The neural network model to use. Initialization of the :class:`GradientPINN` class.
:param AbstractProblem problem: The formulation of the problem. It must
inherit from at least :param AbstractProblem problem: The problem to be solved.
:class:`~pina.problem.spatial_problem.SpatialProblem` to compute It must inherit from at least :class:`~pina.problem.SpatialProblem`
the gradient of the loss. to compute the gradient of the loss.
:param torch.optim.Optimizer optimizer: The neural network optimizer to :param torch.nn.Module model: The neural network model to be used.
use; default `None`. :param torch.optim.Optimizer optimizer: The optimizer to be used.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler; If `None`, the Adam optimizer is used. Default is ``None``.
default `None`. :param torch.optim.LRScheduler scheduler: Learning rate scheduler.
:param WeightingInterface weighting: The weighting schema to use; If `None`, the constant learning rate scheduler is used.
default `None`. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized; :param WeightingInterface weighting: The weighting schema to be used.
default `None`. If `None`, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the Mean Squared Error (MSE) loss is used.
Default is `None`.
:raises ValueError: If the problem is not a SpatialProblem.
""" """
super().__init__( super().__init__(
model=model, model=model,
@@ -102,14 +104,12 @@ class GradientPINN(PINN):
def loss_phys(self, samples, equation): def loss_phys(self, samples, equation):
""" """
Computes the physics loss for the GPINN solver based on given Computes the physics loss for the physics-informed solver based on the
samples and equation. provided samples and equation.
:param LabelTensor samples: The samples to evaluate the physics loss. :param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation :param EquationInterface equation: The governing equation.
representing the physics. :return: The computed physics loss.
:return: The physics loss calculated based on given
samples and equation.
:rtype: LabelTensor :rtype: LabelTensor
""" """
# classical PINN loss # classical PINN loss

View File

@@ -1,4 +1,4 @@
"""Module for Physics Informed Neural Network.""" """Module for the Physics-Informed Neural Network solver."""
import torch import torch
@@ -9,14 +9,13 @@ from ...problem import InverseProblem
class PINN(PINNInterface, SingleSolverInterface): class PINN(PINNInterface, SingleSolverInterface):
r""" r"""
Physics Informed Neural Network (PINN) solver class. Physics-Informed Neural Network (PINN) solver class.
This class implements Physics Informed Neural This class implements Physics-Informed Neural Network solver, using a user
Network solver, using a user specified ``model`` to solve a specific specified ``model`` to solve a specific ``problem``.
``problem``. It can be used for solving both forward and inverse problems. It can be used to solve both forward and inverse problems.
The Physics Informed Network aims to find The Physics Informed Neural Network solver aims to find the solution
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential problem:
of the differential problem:
.. math:: .. math::
@@ -26,16 +25,15 @@ class PINN(PINNInterface, SingleSolverInterface):
\mathbf{x}\in\partial\Omega \mathbf{x}\in\partial\Omega
\end{cases} \end{cases}
minimizing the loss function minimizing the loss function:
.. math:: .. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) + \mathcal{L}(\mathcal{A}[\mathbf{u}](\mathbf{x}_i)) +
\frac{1}{N}\sum_{i=1}^N \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)) \mathcal{L}(\mathcal{B}[\mathbf{u}](\mathbf{x}_i)),
where :math:`\mathcal{L}` is a specific loss function, where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
default Mean Square Error:
.. math:: .. math::
\mathcal{L}(v) = \| v \|^2_2. \mathcal{L}(v) = \| v \|^2_2.
@@ -58,16 +56,20 @@ class PINN(PINNInterface, SingleSolverInterface):
loss=None, loss=None,
): ):
""" """
:param torch.nn.Module model: The neural network model to use. Initialization of the :class:`PINN` class.
:param AbstractProblem problem: The formulation of the problem.
:param torch.optim.Optimizer optimizer: The neural network optimizer to :param AbstractProblem problem: The problem to be solved.
use; default `None`. :param torch.nn.Module model: The neural network model to be used.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler; :param torch.optim.Optimizer optimizer: The optimizer to be used.
default `None`. If `None`, the Adam optimizer is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to use; :param torch.optim.LRScheduler scheduler: Learning rate scheduler.
default `None`. If `None`, the constant learning rate scheduler is used.
:param torch.nn.Module loss: The loss function to be minimized; Default is ``None``.
default `None`. :param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the Mean Squared Error (MSE) loss is used.
Default is `None`.
""" """
super().__init__( super().__init__(
model=model, model=model,
@@ -80,14 +82,12 @@ class PINN(PINNInterface, SingleSolverInterface):
def loss_phys(self, samples, equation): def loss_phys(self, samples, equation):
""" """
Computes the physics loss for the PINN solver based on given Computes the physics loss for the physics-informed solver based on the
samples and equation. provided samples and equation.
:param LabelTensor samples: The samples to evaluate the physics loss. :param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation :param EquationInterface equation: The governing equation.
representing the physics. :return: The computed physics loss.
:return: The physics loss calculated based on given
samples and equation.
:rtype: LabelTensor :rtype: LabelTensor
""" """
residual = self.compute_residual(samples=samples, equation=equation) residual = self.compute_residual(samples=samples, equation=equation)

View File

@@ -1,4 +1,4 @@
"""Module for Physics Informed Neural Network Interface.""" """Module for the Physics Informed Neural Network Interface."""
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
import torch import torch
@@ -17,14 +17,13 @@ from ...condition import (
class PINNInterface(SolverInterface, metaclass=ABCMeta): class PINNInterface(SolverInterface, metaclass=ABCMeta):
""" """
Base PINN solver class. This class implements the Solver Interface Base class for Physics-Informed Neural Network (PINN) solvers, implementing
for Physics Informed Neural Network solver. the :class:`~pina.solver.SolverInterface` class.
This class can be used to define PINNs with multiple ``optimizers``, The `PINNInterface` class can be used to define PINNs that work with one or
and/or ``models``. multiple optimizers and/or models. By default, it is compatible with
By default it takes :class:`~pina.problem.abstract_problem.AbstractProblem`, problems defined by :class:`~pina.problem.AbstractProblem`, and users can
so the user can choose what type of problem the implemented solver, choose the problem type the solver is meant to address.
inheriting from this class, is designed to solve.
""" """
accepted_conditions_types = ( accepted_conditions_types = (
@@ -35,9 +34,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
def __init__(self, problem, loss=None, **kwargs): def __init__(self, problem, loss=None, **kwargs):
""" """
:param AbstractProblem problem: A problem definition instance. Initialization of the :class:`PINNInterface` class.
:param torch.nn.Module loss: The loss function to be minimized,
default `None`. :param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the Mean Squared Error (MSE) loss is used.
Default is ``None``.
:param kwargs: Additional keyword arguments to be passed to the
:class:`~pina.solver.SolverInterface` class.
""" """
if loss is None: if loss is None:
@@ -62,10 +66,28 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
self.__metric = None self.__metric = None
def optimization_cycle(self, batch): def optimization_cycle(self, batch):
"""
The optimization cycle for the PINN solver.
This method allows to call `_run_optimization_cycle` with the physics
loss as argument, thus distinguishing the training step from the
validation and test steps.
:param dict batch: The batch of data to use in the optimization cycle.
:return: The loss of the optimization cycle.
:rtype: torch.Tensor
"""
return self._run_optimization_cycle(batch, self.loss_phys) return self._run_optimization_cycle(batch, self.loss_phys)
@torch.set_grad_enabled(True) @torch.set_grad_enabled(True)
def validation_step(self, batch): def validation_step(self, batch):
"""
The validation step for the PINN solver.
:param dict batch: The batch of data to use in the validation step.
:return: The loss of the validation step.
:rtype: torch.Tensor
"""
losses = self._run_optimization_cycle(batch, self._residual_loss) losses = self._run_optimization_cycle(batch, self._residual_loss)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor) loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
self.store_log("val_loss", loss, self.get_batch_size(batch)) self.store_log("val_loss", loss, self.get_batch_size(batch))
@@ -73,6 +95,13 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
@torch.set_grad_enabled(True) @torch.set_grad_enabled(True)
def test_step(self, batch): def test_step(self, batch):
"""
The test step for the PINN solver.
:param dict batch: The batch of data to use in the test step.
:return: The loss of the test step.
:rtype: torch.Tensor
"""
losses = self._run_optimization_cycle(batch, self._residual_loss) losses = self._run_optimization_cycle(batch, self._residual_loss)
loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor) loss = self.weighting.aggregate(losses).as_subclass(torch.Tensor)
self.store_log("test_loss", loss, self.get_batch_size(batch)) self.store_log("test_loss", loss, self.get_batch_size(batch))
@@ -80,14 +109,14 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
def loss_data(self, input_pts, output_pts): def loss_data(self, input_pts, output_pts):
""" """
The data loss for the PINN solver. It computes the loss between Compute the data loss for the PINN solver by evaluating the loss
the network output against the true solution. This function between the network's output and the true solution. This method
should not be override if not intentionally. should only be overridden intentionally.
:param LabelTensor input_pts: The input to the neural networks. :param LabelTensor input_pts: The input points to the neural network.
:param LabelTensor output_pts: The true solution to compare the :param LabelTensor output_pts: The true solution to compare with the
network solution. network's output.
:return: The residual loss averaged on the input coordinates :return: The supervised loss, averaged over the number of observations.
:rtype: torch.Tensor :rtype: torch.Tensor
""" """
return self._loss(self.forward(input_pts), output_pts) return self._loss(self.forward(input_pts), output_pts)
@@ -95,28 +124,23 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
@abstractmethod @abstractmethod
def loss_phys(self, samples, equation): def loss_phys(self, samples, equation):
""" """
Computes the physics loss for the physics informed solver based on given Computes the physics loss for the physics-informed solver based on the
samples and equation. This method must be override by all inherited provided samples and equation. This method must be overridden in
classes and it is the core to define a new physics informed solver. subclasses. It distinguishes different types of PINN solvers.
:param LabelTensor samples: The samples to evaluate the physics loss. :param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation :param EquationInterface equation: The governing equation.
representing the physics. :return: The computed physics loss.
:return: The physics loss calculated based on given
samples and equation.
:rtype: LabelTensor :rtype: LabelTensor
""" """
def compute_residual(self, samples, equation): def compute_residual(self, samples, equation):
""" """
Compute the residual for Physics Informed learning. This function Compute the residuals of the equation.
returns the :obj:`~pina.equation.equation.Equation` specified in the
:obj:`~pina.condition.Condition` evaluated at the ``samples`` points.
:param LabelTensor samples: The samples to evaluate the physics loss. :param LabelTensor samples: The samples to evaluate the loss.
:param EquationInterface equation: The governing equation :param EquationInterface equation: The governing equation.
representing the physics. :return: The residual of the solution of the model.
:return: The residual of the neural network solution.
:rtype: LabelTensor :rtype: LabelTensor
""" """
try: try:
@@ -129,10 +153,27 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
return residual return residual
def _residual_loss(self, samples, equation): def _residual_loss(self, samples, equation):
"""
Compute the residual loss.
:param LabelTensor samples: The samples to evaluate the loss.
:param EquationInterface equation: The governing equation.
:return: The residual loss.
:rtype: torch.Tensor
"""
residuals = self.compute_residual(samples, equation) residuals = self.compute_residual(samples, equation)
return self.loss(residuals, torch.zeros_like(residuals)) return self.loss(residuals, torch.zeros_like(residuals))
def _run_optimization_cycle(self, batch, loss_residuals): def _run_optimization_cycle(self, batch, loss_residuals):
"""
Compute, given a batch, the loss for each condition and return a
dictionary with the condition name as key and the loss as value.
:param dict batch: The batch of data to use in the optimization cycle.
:param function loss_residuals: The loss function to be minimized.
:return: The loss for each condition.
:rtype dict
"""
condition_loss = {} condition_loss = {}
for condition_name, points in batch: for condition_name, points in batch:
self.__metric = condition_name self.__metric = condition_name
@@ -158,8 +199,7 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
def _clamp_inverse_problem_params(self): def _clamp_inverse_problem_params(self):
""" """
Clamps the parameters of the inverse problem Clamps the parameters of the inverse problem solver to specified ranges.
solver to the specified ranges.
""" """
for v in self._params: for v in self._params:
self._params[v].data.clamp_( self._params[v].data.clamp_(
@@ -170,7 +210,10 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
@property @property
def loss(self): def loss(self):
""" """
Loss used for training. The loss used for training.
:return: The loss function used for training.
:rtype: torch.nn.Module
""" """
return self._loss return self._loss
@@ -178,5 +221,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
def current_condition_name(self): def current_condition_name(self):
""" """
The current condition name. The current condition name.
:return: The current condition name.
:rtype: str
""" """
return self.__metric return self.__metric

View File

@@ -1,4 +1,4 @@
"""Module for Residual-Based Attention PINN.""" """Module for the Residual-Based Attention PINN solver."""
from copy import deepcopy from copy import deepcopy
import torch import torch
@@ -9,14 +9,14 @@ from ...utils import check_consistency
class RBAPINN(PINN): class RBAPINN(PINN):
r""" r"""
Residual-based Attention PINN (RBAPINN) solver class. Residual-based Attention Physics-Informed Neural Network (RBAPINN) solver
This class implements Residual-based Attention Physics Informed Neural class. This class implements the Residual-based Attention Physics-Informed
Network solver, using a user specified ``model`` to solve a specific Neural Network solver, using a user specified ``model`` to solve a specific
``problem``. It can be used for solving both forward and inverse problems. ``problem``. It can be used to solve both forward and inverse problems.
The Residual-based Attention Physics Informed Neural Network aims to find The Residual-based Attention Physics-Informed Neural Network solver aims to
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` find the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a
of the differential problem: differential problem:
.. math:: .. math::
@@ -26,7 +26,7 @@ class RBAPINN(PINN):
\mathbf{x}\in\partial\Omega \mathbf{x}\in\partial\Omega
\end{cases} \end{cases}
minimizing the loss function minimizing the loss function:
.. math:: .. math::
@@ -38,23 +38,23 @@ class RBAPINN(PINN):
\left( \mathcal{B}[\mathbf{u}](\mathbf{x}) \left( \mathcal{B}[\mathbf{u}](\mathbf{x})
\right), \right),
denoting the weights as denoting the weights as:
:math:`\lambda_{\Omega}^1, \dots, \lambda_{\Omega}^{N_\Omega}` and :math:`\lambda_{\Omega}^1, \dots, \lambda_{\Omega}^{N_\Omega}` and
:math:`\lambda_{\partial \Omega}^1, \dots, :math:`\lambda_{\partial \Omega}^1, \dots,
\lambda_{\Omega}^{N_\partial \Omega}` \lambda_{\Omega}^{N_\partial \Omega}`
for :math:`\Omega` and :math:`\partial \Omega`, respectively. for :math:`\Omega` and :math:`\partial \Omega`, respectively.
Residual-based Attention Physics Informed Neural Network computes Residual-based Attention Physics-Informed Neural Network updates the weights
the weights by updating them at every epoch as follows of the residuals at every epoch as follows:
.. math:: .. math::
\lambda_i^{k+1} \leftarrow \gamma\lambda_i^{k} + \lambda_i^{k+1} \leftarrow \gamma\lambda_i^{k} +
\eta\frac{\lvert r_i\rvert}{\max_j \lvert r_j\rvert}, \eta\frac{\lvert r_i\rvert}{\max_j \lvert r_j\rvert},
where :math:`r_i` denotes the residual at point :math:`i`, where :math:`r_i` denotes the residual at point :math:`i`, :math:`\gamma`
:math:`\gamma` denotes the decay rate, and :math:`\eta` is denotes the decay rate, and :math:`\eta` is the learning rate for the
the learning rate for the weights' update. weights' update.
.. seealso:: .. seealso::
**Original reference**: Sokratis J. Anagnostopoulos, Juan D. Toscano, **Original reference**: Sokratis J. Anagnostopoulos, Juan D. Toscano,
@@ -78,20 +78,25 @@ class RBAPINN(PINN):
gamma=0.999, gamma=0.999,
): ):
""" """
:param torch.nn.Module model: The neural network model to use. Initialization of the :class:`RBAPINN` class.
:param AbstractProblem problem: The formulation of the problem.
:param torch.optim.Optimizer optimizer: The neural network optimizer to :param AbstractProblem problem: The problem to be solved.
use; default `None`. :param torch.nn.Module model: The neural network model to be used.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler; :param torch.optim.Optimizer optimizer: The optimizer to be used.
default `None`. If `None`, the Adam optimizer is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to use; :param torch.optim.LRScheduler scheduler: Learning rate scheduler.
default `None`. If `None`, the constant learning rate scheduler is used.
:param torch.nn.Module loss: The loss function to be minimized; Default is ``None``.
default `None`. :param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the Mean Squared Error (MSE) loss is used.
Default is `None`.
:param float | int eta: The learning rate for the weights of the :param float | int eta: The learning rate for the weights of the
residual; default 0.001. residuals. Default is ``0.001``.
:param float gamma: The decay parameter in the update of the weights :param float gamma: The decay parameter in the update of the weights
of the residual. Must be between 0 and 1; default 0.999. of the residuals. Must be between ``0`` and ``1``.
Default is ``0.999``.
""" """
super().__init__( super().__init__(
model=model, model=model,
@@ -122,6 +127,11 @@ class RBAPINN(PINN):
# for now RBAPINN is implemented only for batch_size = None # for now RBAPINN is implemented only for batch_size = None
def on_train_start(self): def on_train_start(self):
"""
Hook method called at the beginning of training.
:raises NotImplementedError: If the batch size is not ``None``.
"""
if self.trainer.batch_size is not None: if self.trainer.batch_size is not None:
raise NotImplementedError( raise NotImplementedError(
"RBAPINN only works with full batch " "RBAPINN only works with full batch "
@@ -132,11 +142,11 @@ class RBAPINN(PINN):
def _vect_to_scalar(self, loss_value): def _vect_to_scalar(self, loss_value):
""" """
Elaboration of the pointwise loss. Computation of the scalar loss.
:param LabelTensor loss_value: the matrix of pointwise loss. :param LabelTensor loss_value: the tensor of pointwise losses.
:raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
:return: the scalar loss. :return: The computed scalar loss.
:rtype LabelTensor :rtype LabelTensor
""" """
if self.loss.reduction == "mean": if self.loss.reduction == "mean":
@@ -152,14 +162,12 @@ class RBAPINN(PINN):
def loss_phys(self, samples, equation): def loss_phys(self, samples, equation):
""" """
Computes the physics loss for the residual-based attention PINN Computes the physics loss for the physics-informed solver based on the
solver based on given samples and equation. provided samples and equation.
:param LabelTensor samples: The samples to evaluate the physics loss. :param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation :param EquationInterface equation: The governing equation.
representing the physics. :return: The computed physics loss.
:return: The physics loss calculated based on given
samples and equation.
:rtype: LabelTensor :rtype: LabelTensor
""" """
residual = self.compute_residual(samples=samples, equation=equation) residual = self.compute_residual(samples=samples, equation=equation)

View File

@@ -11,13 +11,15 @@ from .pinn_interface import PINNInterface
class Weights(torch.nn.Module): class Weights(torch.nn.Module):
""" """
This class aims to implements the mask model for the Implementation of the mask model for the self-adaptive weights of the
self-adaptive weights of the Self-Adaptive PINN solver. :class:`SelfAdaptivePINN` solver.
""" """
def __init__(self, func): def __init__(self, func):
""" """
:param torch.nn.Module func: the mask module of SAPINN. Initialization of the :class:`Weights` class.
:param torch.nn.Module func: the mask model.
""" """
super().__init__() super().__init__()
check_consistency(func, torch.nn.Module) check_consistency(func, torch.nn.Module)
@@ -27,7 +29,6 @@ class Weights(torch.nn.Module):
def forward(self): def forward(self):
""" """
Forward pass implementation for the mask module. Forward pass implementation for the mask module.
It returns the function on the weights evaluation.
:return: evaluation of self adaptive weights through the mask. :return: evaluation of self adaptive weights through the mask.
:rtype: torch.Tensor :rtype: torch.Tensor
@@ -37,14 +38,14 @@ class Weights(torch.nn.Module):
class SelfAdaptivePINN(PINNInterface, MultiSolverInterface): class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
r""" r"""
Self Adaptive Physics Informed Neural Network (SelfAdaptivePINN) Self-Adaptive Physics-Informed Neural Network (SelfAdaptivePINN) solver
solver class. This class implements Self-Adaptive Physics Informed Neural class. This class implements the Self-Adaptive Physics-Informed Neural
Network solver, using a user specified ``model`` to solve a specific Network solver, using a user specified ``model`` to solve a specific
``problem``. It can be used for solving both forward and inverse problems. ``problem``. It can be used to solve both forward and inverse problems.
The Self Adapive Physics Informed Neural Network aims to find The Self-Adapive Physics-Informed Neural Network solver aims to find the
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential
of the differential problem: problem:
.. math:: .. math::
@@ -54,9 +55,10 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
\mathbf{x}\in\partial\Omega \mathbf{x}\in\partial\Omega
\end{cases} \end{cases}
integrating the pointwise loss evaluation through a mask :math:`m` and integrating pointwise loss evaluation using a mask :math:m and self-adaptive
self adaptive weights that permit to focus the loss function on weights, which allow the model to focus on regions of the domain where the
specific training samples. residual is higher.
The loss function to solve the problem is The loss function to solve the problem is
.. math:: .. math::
@@ -69,24 +71,23 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
\left( \mathcal{B}[\mathbf{u}](\mathbf{x}) \left( \mathcal{B}[\mathbf{u}](\mathbf{x})
\right), \right),
denoting the self adaptive weights as denoting the self adaptive weights as
:math:`\lambda_{\Omega}^1, \dots, \lambda_{\Omega}^{N_\Omega}` and :math:`\lambda_{\Omega}^1, \dots, \lambda_{\Omega}^{N_\Omega}` and
:math:`\lambda_{\partial \Omega}^1, \dots, :math:`\lambda_{\partial \Omega}^1, \dots,
\lambda_{\Omega}^{N_\partial \Omega}` \lambda_{\Omega}^{N_\partial \Omega}`
for :math:`\Omega` and :math:`\partial \Omega`, respectively. for :math:`\Omega` and :math:`\partial \Omega`, respectively.
Self Adaptive Physics Informed Neural Network identifies the solution The Self-Adaptive Physics-Informed Neural Network solver identifies the
and appropriate self adaptive weights by solving the following problem solution and appropriate self adaptive weights by solving the following
optimization problem:
.. math:: .. math::
\min_{w} \max_{\lambda_{\Omega}^k, \lambda_{\partial \Omega}^s} \min_{w} \max_{\lambda_{\Omega}^k, \lambda_{\partial \Omega}^s}
\mathcal{L} , \mathcal{L} ,
where :math:`w` denotes the network parameters, and where :math:`w` denotes the network parameters, and :math:`\mathcal{L}` is a
:math:`\mathcal{L}` is a specific loss specific loss function, , typically the MSE:
function, default Mean Square Error:
.. math:: .. math::
\mathcal{L}(v) = \| v \|^2_2. \mathcal{L}(v) = \| v \|^2_2.
@@ -112,23 +113,29 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
loss=None, loss=None,
): ):
""" """
:param AbstractProblem problem: The formulation of the problem. Initialization of the :class:`SelfAdaptivePINN` class.
:param torch.nn.Module model: The neural network model to use for
the model. :param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module weight_function: The neural network model :param torch.nn.Module model: The model to be used.
related to the Self-Adaptive PINN mask; default `torch.nn.Sigmoid()` :param torch.nn.Module weight_function: The Self-Adaptive mask model.
:param torch.optim.Optimizer optimizer_model: The neural network Default is ``torch.nn.Sigmoid()``.
optimizer to use for the model network; default `None`. :param torch.optim.Optimizer optimizer_model: The optimizer of the
:param torch.optim.Optimizer optimizer_weights: The neural network ``model``. If `None`, the Adam optimizer is used.
optimizer to use for mask model; default `None`. Default is ``None``.
:param torch.optim.Optimizer optimizer_weights: The optimizer of the
``weight_function``. If `None`, the Adam optimizer is used.
Default is ``None``.
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler :param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
for the model; default `None`. for the ``model``. If `None`, the constant learning rate scheduler
is used. Default is ``None``.
:param torch.optim.LRScheduler scheduler_weights: Learning rate :param torch.optim.LRScheduler scheduler_weights: Learning rate
scheduler for the mask model; default `None`. scheduler for the ``weight_function``. If `None`, the constant
:param WeightingInterface weighting: The weighting schema to use; learning rate scheduler is used. Default is ``None``.
default `None`. :param WeightingInterface weighting: The weighting schema to be used.
:param torch.nn.Module loss: The loss function to be minimized; If `None`, no weighting schema is used. Default is ``None``.
default `None`. :param torch.nn.Module loss: The loss function to be minimized.
If `None`, the Mean Squared Error (MSE) loss is used.
Default is `None`.
""" """
# check consistency weitghs_function # check consistency weitghs_function
check_consistency(weight_function, torch.nn.Module) check_consistency(weight_function, torch.nn.Module)
@@ -155,16 +162,11 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
self._vectorial_loss.reduction = "none" self._vectorial_loss.reduction = "none"
def forward(self, x): def forward(self, x):
r""" """
Forward pass implementation for the PINN Forward pass.
solver. It returns the function
evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
:math:`\mathbf{x}`.
:param LabelTensor x: Input tensor for the SAPINN solver. It expects :param LabelTensor x: Input tensor.
a tensor :math:`N \\times D`, where :math:`N` the number of points :return: The output of the neural network.
in the mesh, :math:`D` the dimension of the problem,
:return: PINN solution.
:rtype: LabelTensor :rtype: LabelTensor
""" """
return self.model(x) return self.model(x)
@@ -173,9 +175,8 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
""" """
Solver training step, overridden to perform manual optimization. Solver training step, overridden to perform manual optimization.
:param batch: The batch element in the dataloader. :param dict batch: The batch element in the dataloader.
:type batch: tuple :return: The aggregated loss.
:return: The sum of the loss functions.
:rtype: LabelTensor :rtype: LabelTensor
""" """
# Weights optimization # Weights optimization
@@ -194,7 +195,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
def configure_optimizers(self): def configure_optimizers(self):
""" """
Optimizer configuration for the SelfAdaptive PINN solver. Optimizer configuration.
:return: The optimizers and the schedulers :return: The optimizers and the schedulers
:rtype: tuple(list, list) :rtype: tuple(list, list)
@@ -221,16 +222,13 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
def on_train_batch_end(self, outputs, batch, batch_idx): def on_train_batch_end(self, outputs, batch, batch_idx):
""" """
This method is called at the end of each training batch, and ovverides This method is called at the end of each training batch and overrides
the PytorchLightining implementation for logging the checkpoints. the PyTorch Lightning implementation to log checkpoints.
:param torch.Tensor outputs: The output from the model for the :param torch.Tensor outputs: The ``model``'s output for the current
current batch. batch.
:param tuple batch: The current batch of data. :param dict batch: The current batch of data.
:param int batch_idx: The index of the current batch. :param int batch_idx: The index of the current batch.
:return: Whatever is returned by the parent
method ``on_train_batch_end``.
:rtype: Any
""" """
# increase by one the counter of optimization to save loggers # increase by one the counter of optimization to save loggers
( (
@@ -241,12 +239,10 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
def on_train_start(self): def on_train_start(self):
""" """
This method is called at the start of the training for setting This method is called at the start of the training process to set the
the self adaptive weights as parameters of the mask model. self-adaptive weights as parameters of the mask model.
:return: Whatever is returned by the parent :raises NotImplementedError: If the batch size is not ``None``.
method ``on_train_start``.
:rtype: Any
""" """
if self.trainer.batch_size is not None: if self.trainer.batch_size is not None:
raise NotImplementedError( raise NotImplementedError(
@@ -270,9 +266,9 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
def on_load_checkpoint(self, checkpoint): def on_load_checkpoint(self, checkpoint):
""" """
Override the Pytorch Lightning ``on_load_checkpoint`` to handle Override of the Pytorch Lightning ``on_load_checkpoint`` method to
checkpoints for Self-Adaptive Weights. This method should not be handle checkpoints for Self-Adaptive Weights. This method should not be
overridden if not intentionally. overridden, if not intentionally.
:param dict checkpoint: Pytorch Lightning checkpoint dict. :param dict checkpoint: Pytorch Lightning checkpoint dict.
""" """
@@ -289,14 +285,13 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
def loss_phys(self, samples, equation): def loss_phys(self, samples, equation):
""" """
Computation of the physical loss for SelfAdaptive PINN solver. Computes the physics loss for the physics-informed solver based on the
provided samples and equation.
:param LabelTensor samples: Input samples to evaluate the physics loss. :param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: the governing equation representing :param EquationInterface equation: The governing equation.
the physics. :return: The computed physics loss.
:rtype: LabelTensor
:return: tuple with weighted and not weighted scalar loss
:rtype: List[LabelTensor, LabelTensor]
""" """
residual = self.compute_residual(samples, equation) residual = self.compute_residual(samples, equation)
weights = self.weights_dict[self.current_condition_name].forward() weights = self.weights_dict[self.current_condition_name].forward()
@@ -307,12 +302,11 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
def _vect_to_scalar(self, loss_value): def _vect_to_scalar(self, loss_value):
""" """
Elaboration of the pointwise loss through the mask model and the Computation of the scalar loss.
self adaptive weights
:param LabelTensor loss_value: the matrix of pointwise loss :param LabelTensor loss_value: the tensor of pointwise losses.
:raises RuntimeError: If the loss reduction is not ``mean`` or ``sum``.
:return: the scalar loss :return: The computed scalar loss.
:rtype LabelTensor :rtype LabelTensor
""" """
if self.loss.reduction == "mean": if self.loss.reduction == "mean":
@@ -329,33 +323,29 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
@property @property
def model(self): def model(self):
""" """
Return the mask models associate to the application of The model.
the mask to the self adaptive weights for each loss that
compones the global loss of the problem.
:return: The ModuleDict for mask models. :return: The model.
:rtype: torch.nn.ModuleDict :rtype: torch.nn.Module
""" """
return self.models[0] return self.models[0]
@property @property
def weights_dict(self): def weights_dict(self):
""" """
Return the mask models associate to the application of The self-adaptive weights.
the mask to the self adaptive weights for each loss that
compones the global loss of the problem.
:return: The ModuleDict for mask models. :return: The self-adaptive weights.
:rtype: torch.nn.ModuleDict :rtype: torch.nn.Module
""" """
return self.models[1] return self.models[1]
@property @property
def scheduler_model(self): def scheduler_model(self):
""" """
Returns the scheduler associated with the neural network model. The scheduler associated to the model.
:return: The scheduler for the neural network model. :return: The scheduler for the model.
:rtype: torch.optim.lr_scheduler._LRScheduler :rtype: torch.optim.lr_scheduler._LRScheduler
""" """
return self.schedulers[0] return self.schedulers[0]
@@ -363,7 +353,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
@property @property
def scheduler_weights(self): def scheduler_weights(self):
""" """
Returns the scheduler associated with the mask model (if applicable). The scheduler associated to the mask model.
:return: The scheduler for the mask model. :return: The scheduler for the mask model.
:rtype: torch.optim.lr_scheduler._LRScheduler :rtype: torch.optim.lr_scheduler._LRScheduler
@@ -373,9 +363,9 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
@property @property
def optimizer_model(self): def optimizer_model(self):
""" """
Returns the optimizer associated with the neural network model. Returns the optimizer associated to the model.
:return: The optimizer for the neural network model. :return: The optimizer for the model.
:rtype: torch.optim.Optimizer :rtype: torch.optim.Optimizer
""" """
return self.optimizers[0] return self.optimizers[0]
@@ -383,7 +373,7 @@ class SelfAdaptivePINN(PINNInterface, MultiSolverInterface):
@property @property
def optimizer_weights(self): def optimizer_weights(self):
""" """
Returns the optimizer associated with the mask model (if applicable). The optimizer associated to the mask model.
:return: The optimizer for the mask model. :return: The optimizer for the mask model.
:rtype: torch.optim.Optimizer :rtype: torch.optim.Optimizer