fix pinn doc
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
"""Module for Competitive PINN."""
|
||||
"""Module for the Competitive PINN solver."""
|
||||
|
||||
import copy
|
||||
import torch
|
||||
@@ -10,14 +10,14 @@ from ..solver import MultiSolverInterface
|
||||
|
||||
class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
r"""
|
||||
Competitive Physics Informed Neural Network (PINN) solver class.
|
||||
This class implements Competitive Physics Informed Neural
|
||||
Network solver, using a user specified ``model`` to solve a specific
|
||||
``problem``. It can be used for solving both forward and inverse problems.
|
||||
Competitive Physics-Informed Neural Network (CompetitivePINN) solver class.
|
||||
This class implements the Competitive Physics-Informed Neural Network
|
||||
solver, using a user specified ``model`` to solve a specific ``problem``.
|
||||
It can be used to solve both forward and inverse problems.
|
||||
|
||||
The Competitive Physics Informed Network aims to find
|
||||
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
|
||||
of the differential problem:
|
||||
The Competitive Physics-Informed Neural Network solver aims to find the
|
||||
solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential
|
||||
problem:
|
||||
|
||||
.. math::
|
||||
|
||||
@@ -27,18 +27,18 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
\mathbf{x}\in\partial\Omega
|
||||
\end{cases}
|
||||
|
||||
with a minimization (on ``model`` parameters) maximation (
|
||||
on ``discriminator`` parameters) of the loss function
|
||||
minimizing the loss function with respect to the model parameters, while
|
||||
maximizing it with respect to the discriminator parameters:
|
||||
|
||||
.. math::
|
||||
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(D(\mathbf{x}_i)\mathcal{A}[\mathbf{u}](\mathbf{x}_i))+
|
||||
\frac{1}{N}\sum_{i=1}^N
|
||||
\mathcal{L}(D(\mathbf{x}_i)\mathcal{B}[\mathbf{u}](\mathbf{x}_i))
|
||||
\mathcal{L}(D(\mathbf{x}_i)\mathcal{B}[\mathbf{u}](\mathbf{x}_i)),
|
||||
|
||||
where :math:`D` is the discriminator network, which tries to find the points
|
||||
where the network performs worst, and :math:`\mathcal{L}` is a specific loss
|
||||
function, default Mean Square Error:
|
||||
where :math:D is the discriminator network, which identifies the points
|
||||
where the model performs worst, and :math:\mathcal{L} is a specific loss
|
||||
function, typically the MSE:
|
||||
|
||||
.. math::
|
||||
\mathcal{L}(v) = \| v \|^2_2.
|
||||
@@ -49,10 +49,6 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
"Competitive physics informed networks." International Conference on
|
||||
Learning Representations, ICLR 2022
|
||||
`OpenReview Preprint <https://openreview.net/forum?id=z9SIj-IM7tn>`_.
|
||||
|
||||
.. warning::
|
||||
This solver does not currently support the possibility to pass
|
||||
``extra_feature``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -68,24 +64,30 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
loss=None,
|
||||
):
|
||||
"""
|
||||
:param AbstractProblem problem: The formulation of the problem.
|
||||
:param torch.nn.Module model: The neural network model to use
|
||||
for the model.
|
||||
:param torch.nn.Module discriminator: The neural network model to use
|
||||
for the discriminator. If ``None``, the discriminator network will
|
||||
have the same architecture as the model network.
|
||||
:param torch.optim.Optimizer optimizer_model: The neural network
|
||||
optimizer to use for the model network; default `None`.
|
||||
:param torch.optim.Optimizer optimizer_discriminator: The neural network
|
||||
optimizer to use for the discriminator network; default `None`.
|
||||
Initialization of the :class:`CompetitivePINN` class.
|
||||
|
||||
:param AbstractProblem problem: The problem to be solved.
|
||||
:param torch.nn.Module model: The neural network model to be used.
|
||||
:param torch.nn.Module discriminator: The discriminator to be used.
|
||||
If `None`, the discriminator is a deepcopy of the ``model``.
|
||||
Default is ``None``.
|
||||
:param torch.optim.Optimizer optimizer_model: The optimizer of the
|
||||
``model``. If `None`, the Adam optimizer is used.
|
||||
Default is ``None``.
|
||||
:param torch.optim.Optimizer optimizer_discriminator: The optimizer of
|
||||
the ``discriminator``. If `None`, the Adam optimizer is used.
|
||||
Default is ``None``.
|
||||
:param torch.optim.LRScheduler scheduler_model: Learning rate scheduler
|
||||
for the model; default `None`.
|
||||
for the ``model``. If `None`, the constant learning rate scheduler
|
||||
is used. Default is ``None``.
|
||||
:param torch.optim.LRScheduler scheduler_discriminator: Learning rate
|
||||
scheduler for the discriminator; default `None`.
|
||||
:param WeightingInterface weighting: The weighting schema to use;
|
||||
default `None`.
|
||||
:param torch.nn.Module loss: The loss function to be minimized;
|
||||
default `None`.
|
||||
scheduler for the ``discriminator``. If `None`, the constant
|
||||
learning rate scheduler is used. Default is ``None``.
|
||||
:param WeightingInterface weighting: The weighting schema to be used.
|
||||
If `None`, no weighting schema is used. Default is ``None``.
|
||||
:param torch.nn.Module loss: The loss function to be minimized.
|
||||
If `None`, the Mean Squared Error (MSE) loss is used.
|
||||
Default is `None`.
|
||||
"""
|
||||
if discriminator is None:
|
||||
discriminator = copy.deepcopy(model)
|
||||
@@ -103,15 +105,11 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
self.automatic_optimization = False
|
||||
|
||||
def forward(self, x):
|
||||
r"""
|
||||
Forward pass implementation for the PINN solver. It returns the function
|
||||
evaluation :math:`\mathbf{u}(\mathbf{x})` at the control points
|
||||
:math:`\mathbf{x}`.
|
||||
"""
|
||||
Forward pass.
|
||||
|
||||
:param LabelTensor x: Input tensor for the PINN solver. It expects
|
||||
a tensor :math:`N \times D`, where :math:`N` the number of points
|
||||
in the mesh, :math:`D` the dimension of the problem,
|
||||
:return: PINN solution evaluated at contro points.
|
||||
:param LabelTensor x: Input tensor.
|
||||
:return: The output of the neural network.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
return self.neural_net(x)
|
||||
@@ -120,9 +118,8 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
"""
|
||||
Solver training step, overridden to perform manual optimization.
|
||||
|
||||
:param batch: The batch element in the dataloader.
|
||||
:type batch: tuple
|
||||
:return: The sum of the loss functions.
|
||||
:param dict batch: The batch element in the dataloader.
|
||||
:return: The aggregated loss.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# train model
|
||||
@@ -139,14 +136,12 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
|
||||
def loss_phys(self, samples, equation):
|
||||
"""
|
||||
Computes the physics loss for the Competitive PINN solver based on given
|
||||
samples and equation.
|
||||
Computes the physics loss for the physics-informed solver based on the
|
||||
provided samples and equation.
|
||||
|
||||
:param LabelTensor samples: The samples to evaluate the physics loss.
|
||||
:param EquationInterface equation: The governing equation
|
||||
representing the physics.
|
||||
:return: The physics loss calculated based on given
|
||||
samples and equation.
|
||||
:param EquationInterface equation: The governing equation.
|
||||
:return: The computed physics loss.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
# Compute discriminator bets
|
||||
@@ -165,7 +160,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
|
||||
def configure_optimizers(self):
|
||||
"""
|
||||
Optimizer configuration for the Competitive PINN solver.
|
||||
Optimizer configuration.
|
||||
|
||||
:return: The optimizers and the schedulers
|
||||
:rtype: tuple(list, list)
|
||||
@@ -198,16 +193,13 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
|
||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||
"""
|
||||
This method is called at the end of each training batch, and ovverides
|
||||
the PytorchLightining implementation for logging the checkpoints.
|
||||
This method is called at the end of each training batch and overrides
|
||||
the PyTorch Lightning implementation to log checkpoints.
|
||||
|
||||
:param torch.Tensor outputs: The output from the model for the
|
||||
current batch.
|
||||
:param tuple batch: The current batch of data.
|
||||
:param torch.Tensor outputs: The ``model``'s output for the current
|
||||
batch.
|
||||
:param dict batch: The current batch of data.
|
||||
:param int batch_idx: The index of the current batch.
|
||||
:return: Whatever is returned by the parent
|
||||
method ``on_train_batch_end``.
|
||||
:rtype: Any
|
||||
"""
|
||||
# increase by one the counter of optimization to save loggers
|
||||
(
|
||||
@@ -219,9 +211,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
@property
|
||||
def neural_net(self):
|
||||
"""
|
||||
Returns the neural network model.
|
||||
The model.
|
||||
|
||||
:return: The neural network model.
|
||||
:return: The model.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self.models[0]
|
||||
@@ -229,9 +221,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
@property
|
||||
def discriminator(self):
|
||||
"""
|
||||
Returns the discriminator model (if applicable).
|
||||
The discriminator.
|
||||
|
||||
:return: The discriminator model.
|
||||
:return: The discriminator.
|
||||
:rtype: torch.nn.Module
|
||||
"""
|
||||
return self.models[1]
|
||||
@@ -239,9 +231,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
@property
|
||||
def optimizer_model(self):
|
||||
"""
|
||||
Returns the optimizer associated with the neural network model.
|
||||
The optimizer associated to the model.
|
||||
|
||||
:return: The optimizer for the neural network model.
|
||||
:return: The optimizer for the model.
|
||||
:rtype: torch.optim.Optimizer
|
||||
"""
|
||||
return self.optimizers[0]
|
||||
@@ -249,7 +241,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
@property
|
||||
def optimizer_discriminator(self):
|
||||
"""
|
||||
Returns the optimizer associated with the discriminator (if applicable).
|
||||
The optimizer associated to the discriminator.
|
||||
|
||||
:return: The optimizer for the discriminator.
|
||||
:rtype: torch.optim.Optimizer
|
||||
@@ -259,9 +251,9 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
@property
|
||||
def scheduler_model(self):
|
||||
"""
|
||||
Returns the scheduler associated with the neural network model.
|
||||
The scheduler associated to the model.
|
||||
|
||||
:return: The scheduler for the neural network model.
|
||||
:return: The scheduler for the model.
|
||||
:rtype: torch.optim.lr_scheduler._LRScheduler
|
||||
"""
|
||||
return self.schedulers[0]
|
||||
@@ -269,7 +261,7 @@ class CompetitivePINN(PINNInterface, MultiSolverInterface):
|
||||
@property
|
||||
def scheduler_discriminator(self):
|
||||
"""
|
||||
Returns the scheduler associated with the discriminator (if applicable).
|
||||
The scheduler associated to the discriminator.
|
||||
|
||||
:return: The scheduler for the discriminator.
|
||||
:rtype: torch.optim.lr_scheduler._LRScheduler
|
||||
|
||||
Reference in New Issue
Block a user