fix pinn doc

This commit is contained in:
giovanni
2025-03-13 16:50:05 +01:00
committed by Nicola Demo
parent 9a26c94e07
commit 28ef4c823b
8 changed files with 377 additions and 337 deletions

View File

@@ -1,4 +1,4 @@
"""Module for Competitive PINN."""
"""Module for the Competitive PINN solver."""
import copy
import torch
@@ -10,14 +10,14 @@ from ..solver import MultiSolverInterface
class CompetitivePINN(PINNInterface, MultiSolverInterface):
r"""
Competitive Physics Informed Neural Network (PINN) solver class.
This class implements Competitive Physics Informed Neural
Network solver, using a user specified ``model`` to solve a specific
``problem``. It can be used for solving both forward and inverse problems.
Competitive Physics-Informed Neural Network (CompetitivePINN) solver class.
This class implements the Competitive Physics-Informed Neural Network
solver, using a user specified ``model`` to solve a specific ``problem``.
It can be used to solve both forward and inverse problems.
The Competitive Physics Informed Network aims to find
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
of the differential problem:
The Competitive Physics-Informed Neural Network solver aims to find the
solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential
problem:
.. math::
@@ -27,18 +27,18 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
\mathbf{x}\in\partial\Omega
\end{cases}
with a minimization (on ``model`` parameters) maximation (
on ``discriminator`` parameters) of the loss function
minimizing the loss function with respect to the model parameters, while
maximizing it with respect to the discriminator parameters:
.. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(D(\mathbf{x}_i)\mathcal{A}[\mathbf{u}](\mathbf{x}_i))+
\frac{1}{N}\sum_{i=1}^N
\mathcal{L}(D(\mathbf{x}_i)\mathcal{B}[\mathbf{u}](\mathbf{x}_i))
\mathcal{L}(D(\mathbf{x}_i)\mathcal{B}[\mathbf{u}](\mathbf{x}_i)),
where :math:`D` is the discriminator network, which tries to find the points
where the network performs worst, and :math:`\mathcal{L}` is a specific loss
function, default Mean Square Error:
where :math:D is the discriminator network, which identifies the points
where the model performs worst, and :math:\mathcal{L} is a specific loss
function, typically the MSE:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
@@ -49,10 +49,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
"Competitive physics informed networks." International Conference on
Learning Representations, ICLR 2022
`OpenReview Preprint <https://openreview.net/forum?id=z9SIj-IM7tn>`_.
.. warning::
This solver does not currently support the possibility to pass
``extra_feature``.
"""
def __init__(
@@ -68,24 +64,30 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
loss=None,
):
"""
:param AbstractProblem problem: The formulation of the problem.
:param torch.nn.Module model: The neural network model to use
for the model.
:param torch.nn.Module discriminator: The neural network model to use
for the discriminator. If ``None``, the discriminator network will
have the same architecture as the model network.
:param torch.optim.Optimizer optimizer_model: The neural network
optimizer to use for the model network; default `None`.
:param torch.optim.Optimizer optimizer_discriminator: The neural network
optimizer to use for the discriminator network; default `None`.
Initialization of the :class:`CompetitivePINN` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param torch.nn.Module discriminator: The discriminator to be used.
If `None`, the discriminator is a deepcopy of the ``model``.
Default is ``None``.
:param torch.optim.Optimizer optimizer_model: The optimizer of the
``model``. If `None`, the Adam optimizer is used.
Default is ``None``.
:param torch.optim.Optimizer optimizer_discriminator: The optimizer of
the ``discriminator``. If `None`, the Adam optimizer is used.
Default is ``None``.
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
for the model; default `None`.
for the ``model``. If `None`, the constant learning rate scheduler
is used. Default is ``None``.
:param torch.optim.LRScheduler scheduler_discriminator: Learning rate
scheduler for the discriminator; default `None`.
:param WeightingInterface weighting: The weighting schema to use;
default `None`.
:param torch.nn.Module loss: The loss function to be minimized;
default `None`.
scheduler for the ``discriminator``. If `None`, the constant
learning rate scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the Mean Squared Error (MSE) loss is used.
Default is `None`.
"""
if discriminator is None:
discriminator = copy.deepcopy(model)
@@ -103,15 +105,11 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
self.automatic_optimization = False
def forward(self, x):
r"""
Forward pass implementation for the PINN solver. It returns the function
evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
:math:`\mathbf{x}`.
"""
Forward pass.
:param LabelTensor x: Input tensor for the PINN solver. It expects
a tensor :math:`N \times D`, where :math:`N` the number of points
in the mesh, :math:`D` the dimension of the problem,
:return: PINN solution evaluated at contro points.
:param LabelTensor x: Input tensor.
:return: The output of the neural network.
:rtype: LabelTensor
"""
return self.neural_net(x)
@@ -120,9 +118,8 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
"""
Solver training step, overridden to perform manual optimization.
:param batch: The batch element in the dataloader.
:type batch: tuple
:return: The sum of the loss functions.
:param dict batch: The batch element in the dataloader.
:return: The aggregated loss.
:rtype: LabelTensor
"""
# train model
@@ -139,14 +136,12 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
def loss_phys(self, samples, equation):
"""
Computes the physics loss for the Competitive PINN solver based on given
samples and equation.
Computes the physics loss for the physics-informed solver based on the
provided samples and equation.
:param LabelTensor samples: The samples to evaluate the physics loss.
:param EquationInterface equation: The governing equation
representing the physics.
:return: The physics loss calculated based on given
samples and equation.
:param EquationInterface equation: The governing equation.
:return: The computed physics loss.
:rtype: LabelTensor
"""
# Compute discriminator bets
@@ -165,7 +160,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
def configure_optimizers(self):
"""
Optimizer configuration for the Competitive PINN solver.
Optimizer configuration.
:return: The optimizers and the schedulers
:rtype: tuple(list, list)
@@ -198,16 +193,13 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
def on_train_batch_end(self, outputs, batch, batch_idx):
"""
This method is called at the end of each training batch, and ovverides
the PytorchLightining implementation for logging the checkpoints.
This method is called at the end of each training batch and overrides
the PyTorch Lightning implementation to log checkpoints.
:param torch.Tensor outputs: The output from the model for the
current batch.
:param tuple batch: The current batch of data.
:param torch.Tensor outputs: The ``model``'s output for the current
batch.
:param dict batch: The current batch of data.
:param int batch_idx: The index of the current batch.
:return: Whatever is returned by the parent
method ``on_train_batch_end``.
:rtype: Any
"""
# increase by one the counter of optimization to save loggers
(
@@ -219,9 +211,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property
def neural_net(self):
"""
Returns the neural network model.
The model.
:return: The neural network model.
:return: The model.
:rtype: torch.nn.Module
"""
return self.models[0]
@@ -229,9 +221,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property
def discriminator(self):
"""
Returns the discriminator model (if applicable).
The discriminator.
:return: The discriminator model.
:return: The discriminator.
:rtype: torch.nn.Module
"""
return self.models[1]
@@ -239,9 +231,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property
def optimizer_model(self):
"""
Returns the optimizer associated with the neural network model.
The optimizer associated to the model.
:return: The optimizer for the neural network model.
:return: The optimizer for the model.
:rtype: torch.optim.Optimizer
"""
return self.optimizers[0]
@@ -249,7 +241,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property
def optimizer_discriminator(self):
"""
Returns the optimizer associated with the discriminator (if applicable).
The optimizer associated to the discriminator.
:return: The optimizer for the discriminator.
:rtype: torch.optim.Optimizer
@@ -259,9 +251,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property
def scheduler_model(self):
"""
Returns the scheduler associated with the neural network model.
The scheduler associated to the model.
:return: The scheduler for the neural network model.
:return: The scheduler for the model.
:rtype: torch.optim.lr_scheduler._LRScheduler
"""
return self.schedulers[0]
@@ -269,7 +261,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
@property
def scheduler_discriminator(self):
"""
Returns the scheduler associated with the discriminator (if applicable).
The scheduler associated to the discriminator.
:return: The scheduler for the discriminator.
:rtype: torch.optim.lr_scheduler._LRScheduler