fix doc solver
This commit is contained in:
@@ -1,6 +1,4 @@
|
|||||||
"""
|
"""Module for the solver classes."""
|
||||||
TODO
|
|
||||||
"""
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SolverInterface",
|
"SolverInterface",
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Module for GAROM"""
|
"""Module for the GAROM solver."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.modules.loss import _Loss
|
from torch.nn.modules.loss import _Loss
|
||||||
@@ -10,9 +10,9 @@ from ..loss import LossInterface, PowerLoss
|
|||||||
|
|
||||||
class GAROM(MultiSolverInterface):
|
class GAROM(MultiSolverInterface):
|
||||||
"""
|
"""
|
||||||
GAROM solver class. This class implements Generative Adversarial
|
GAROM solver class. This class implements Generative Adversarial Reduced
|
||||||
Reduced Order Model solver, using user specified ``models`` to solve
|
Order Model solver, using user specified ``models`` to solve a specific
|
||||||
a specific order reduction``problem``.
|
order reduction ``problem``.
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
@@ -39,40 +39,28 @@ class GAROM(MultiSolverInterface):
|
|||||||
regularizer=False,
|
regularizer=False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param AbstractProblem problem: The formualation of the problem.
|
Initialization of the :class:`GAROM` class.
|
||||||
:param torch.nn.Module generator: The neural network model to use
|
|
||||||
for the generator.
|
|
||||||
:param torch.nn.Module discriminator: The neural network model to use
|
|
||||||
for the discriminator.
|
|
||||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
|
||||||
default ``None``. If ``loss`` is ``None`` the defualt
|
|
||||||
``PowerLoss(p=1)`` is used, as in the original paper.
|
|
||||||
:param Optimizer optimizer_generator: The neural
|
|
||||||
network optimizer to use for the generator network
|
|
||||||
, default is `torch.optim.Adam`.
|
|
||||||
:param Optimizer optimizer_discriminator: The neural
|
|
||||||
network optimizer to use for the discriminator network
|
|
||||||
, default is `torch.optim.Adam`.
|
|
||||||
:param Scheduler scheduler_generator: Learning
|
|
||||||
rate scheduler for the generator.
|
|
||||||
:param Scheduler scheduler_discriminator: Learning
|
|
||||||
rate scheduler for the discriminator.
|
|
||||||
:param dict scheduler_discriminator_kwargs: LR scheduler constructor
|
|
||||||
keyword args.
|
|
||||||
:param gamma: Ratio of expected loss for generator and discriminator,
|
|
||||||
defaults to 0.3.
|
|
||||||
:type gamma: float
|
|
||||||
:param lambda_k: Learning rate for control theory optimization,
|
|
||||||
defaults to 0.001.
|
|
||||||
:type lambda_k: float
|
|
||||||
:param regularizer: Regularization term in the GAROM loss,
|
|
||||||
defaults to False.
|
|
||||||
:type regularizer: bool
|
|
||||||
|
|
||||||
.. warning::
|
:param AbstractProblem problem: The formulation of the problem.
|
||||||
The algorithm works only for data-driven model. Hence in the
|
:param torch.nn.Module generator: The generator model.
|
||||||
``problem`` definition the codition must only contain ``input``
|
:param torch.nn.Module discriminator: The discriminator model.
|
||||||
(e.g. coefficient parameters, time parameters), and ``target``.
|
:param torch.nn.Module loss: The loss function to be minimized.
|
||||||
|
If ``None``, ``PowerLoss(p=1)`` is used. Default is ``None``.
|
||||||
|
:param Optimizer optimizer_generator: The optimizer for the generator.
|
||||||
|
If `None`, the Adam optimizer is used. Default is ``None``.
|
||||||
|
:param Optimizer optimizer_discriminator: The optimizer for the
|
||||||
|
discriminator. If `None`, the Adam optimizer is used.
|
||||||
|
Default is ``None``.
|
||||||
|
:param Scheduler scheduler_generator: The learning rate scheduler for
|
||||||
|
the generator.
|
||||||
|
:param Scheduler scheduler_discriminator: The learning rate scheduler
|
||||||
|
for the discriminator.
|
||||||
|
:param float gamma: Ratio of expected loss for generator and
|
||||||
|
discriminator. Default is ``0.3``.
|
||||||
|
:param float lambda_k: Learning rate for control theory optimization.
|
||||||
|
Default is ``0.001``.
|
||||||
|
:param bool regularizer: If ``True``, uses a regularization term in the
|
||||||
|
GAROM loss. Default is ``False``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# set loss
|
# set loss
|
||||||
@@ -112,19 +100,15 @@ class GAROM(MultiSolverInterface):
|
|||||||
|
|
||||||
def forward(self, x, mc_steps=20, variance=False):
|
def forward(self, x, mc_steps=20, variance=False):
|
||||||
"""
|
"""
|
||||||
Forward step for GAROM solver
|
Forward pass implementation.
|
||||||
|
|
||||||
:param x: The input tensor.
|
:param torch.Tensor x: The input tensor.
|
||||||
:type x: torch.Tensor
|
:param int mc_steps: Number of Montecarlo samples to approximate the
|
||||||
:param mc_steps: Number of montecarlo samples to approximate the
|
expected value. Default is ``20``.
|
||||||
expected value, defaults to 20.
|
:param bool variance: If ``True``, the method returns also the variance
|
||||||
:type mc_steps: int
|
of the solution. Default is ``False``.
|
||||||
:param variance: Returining also the sample variance of the solution,
|
|
||||||
defaults to False.
|
|
||||||
:type variance: bool
|
|
||||||
:return: The expected value of the generator distribution. If
|
:return: The expected value of the generator distribution. If
|
||||||
``variance=True`` also the
|
``variance=True``, the method returns also the variance.
|
||||||
sample variance is returned.
|
|
||||||
:rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
|
:rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -142,13 +126,24 @@ class GAROM(MultiSolverInterface):
|
|||||||
return mean
|
return mean
|
||||||
|
|
||||||
def sample(self, x):
|
def sample(self, x):
|
||||||
"""TODO"""
|
"""
|
||||||
|
Sample from the generator distribution.
|
||||||
|
|
||||||
|
:param torch.Tensor x: The input tensor.
|
||||||
|
:return: The generated sample.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
# sampling
|
# sampling
|
||||||
return self.generator(x)
|
return self.generator(x)
|
||||||
|
|
||||||
def _train_generator(self, parameters, snapshots):
|
def _train_generator(self, parameters, snapshots):
|
||||||
"""
|
"""
|
||||||
Private method to train the generator network.
|
Train the generator model.
|
||||||
|
|
||||||
|
:param torch.Tensor parameters: The input tensor.
|
||||||
|
:param torch.Tensor snapshots: The target tensor.
|
||||||
|
:return: The residual loss and the generator loss.
|
||||||
|
:rtype: tuple(torch.Tensor, torch.Tensor)
|
||||||
"""
|
"""
|
||||||
optimizer = self.optimizer_generator
|
optimizer = self.optimizer_generator
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@@ -170,16 +165,13 @@ class GAROM(MultiSolverInterface):
|
|||||||
|
|
||||||
def on_train_batch_end(self, outputs, batch, batch_idx):
|
def on_train_batch_end(self, outputs, batch, batch_idx):
|
||||||
"""
|
"""
|
||||||
This method is called at the end of each training batch, and ovverides
|
This method is called at the end of each training batch and overrides
|
||||||
the PytorchLightining implementation for logging the checkpoints.
|
the PyTorch Lightning implementation to log checkpoints.
|
||||||
|
|
||||||
:param torch.Tensor outputs: The output from the model for the
|
:param torch.Tensor outputs: The ``model``'s output for the current
|
||||||
current batch.
|
batch.
|
||||||
:param tuple batch: The current batch of data.
|
:param dict batch: The current batch of data.
|
||||||
:param int batch_idx: The index of the current batch.
|
:param int batch_idx: The index of the current batch.
|
||||||
:return: Whatever is returned by the parent
|
|
||||||
method ``on_train_batch_end``.
|
|
||||||
:rtype: Any
|
|
||||||
"""
|
"""
|
||||||
# increase by one the counter of optimization to save loggers
|
# increase by one the counter of optimization to save loggers
|
||||||
(
|
(
|
||||||
@@ -190,7 +182,12 @@ class GAROM(MultiSolverInterface):
|
|||||||
|
|
||||||
def _train_discriminator(self, parameters, snapshots):
|
def _train_discriminator(self, parameters, snapshots):
|
||||||
"""
|
"""
|
||||||
Private method to train the discriminator network.
|
Train the discriminator model.
|
||||||
|
|
||||||
|
:param torch.Tensor parameters: The input tensor.
|
||||||
|
:param torch.Tensor snapshots: The target tensor.
|
||||||
|
:return: The residual loss and the generator loss.
|
||||||
|
:rtype: tuple(torch.Tensor, torch.Tensor)
|
||||||
"""
|
"""
|
||||||
optimizer = self.optimizer_discriminator
|
optimizer = self.optimizer_discriminator
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
@@ -215,8 +212,15 @@ class GAROM(MultiSolverInterface):
|
|||||||
|
|
||||||
def _update_weights(self, d_loss_real, d_loss_fake):
|
def _update_weights(self, d_loss_real, d_loss_fake):
|
||||||
"""
|
"""
|
||||||
Private method to Update the weights of the generator and discriminator
|
Update the weights of the generator and discriminator models.
|
||||||
networks.
|
|
||||||
|
:param torch.Tensor d_loss_real: The discriminator loss computed on
|
||||||
|
dataset samples.
|
||||||
|
:param torch.Tensor d_loss_fake: The discriminator loss computed on
|
||||||
|
generated samples.
|
||||||
|
:return: The difference between the loss computed on the dataset samples
|
||||||
|
and the loss computed on the generated samples.
|
||||||
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
|
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
|
||||||
@@ -227,11 +231,11 @@ class GAROM(MultiSolverInterface):
|
|||||||
return diff
|
return diff
|
||||||
|
|
||||||
def optimization_cycle(self, batch):
|
def optimization_cycle(self, batch):
|
||||||
"""GAROM solver training step.
|
"""
|
||||||
|
The optimization cycle for the GAROM solver.
|
||||||
|
|
||||||
:param batch: The batch element in the dataloader.
|
:param tuple batch: The batch element in the dataloader.
|
||||||
:type batch: tuple
|
:return: The loss of the optimization cycle.
|
||||||
:return: The sum of the loss functions.
|
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
condition_loss = {}
|
condition_loss = {}
|
||||||
@@ -258,6 +262,13 @@ class GAROM(MultiSolverInterface):
|
|||||||
return condition_loss
|
return condition_loss
|
||||||
|
|
||||||
def validation_step(self, batch):
|
def validation_step(self, batch):
|
||||||
|
"""
|
||||||
|
The validation step for the PINN solver.
|
||||||
|
|
||||||
|
:param dict batch: The batch of data to use in the validation step.
|
||||||
|
:return: The loss of the validation step.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
condition_loss = {}
|
condition_loss = {}
|
||||||
for condition_name, points in batch:
|
for condition_name, points in batch:
|
||||||
parameters, snapshots = (
|
parameters, snapshots = (
|
||||||
@@ -273,6 +284,13 @@ class GAROM(MultiSolverInterface):
|
|||||||
return loss
|
return loss
|
||||||
|
|
||||||
def test_step(self, batch):
|
def test_step(self, batch):
|
||||||
|
"""
|
||||||
|
The test step for the PINN solver.
|
||||||
|
|
||||||
|
:param dict batch: The batch of data to use in the test step.
|
||||||
|
:return: The loss of the test step.
|
||||||
|
:rtype: torch.Tensor
|
||||||
|
"""
|
||||||
condition_loss = {}
|
condition_loss = {}
|
||||||
for condition_name, points in batch:
|
for condition_name, points in batch:
|
||||||
parameters, snapshots = (
|
parameters, snapshots = (
|
||||||
@@ -289,30 +307,60 @@ class GAROM(MultiSolverInterface):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def generator(self):
|
def generator(self):
|
||||||
"""TODO"""
|
"""
|
||||||
|
The generator model.
|
||||||
|
|
||||||
|
:return: The generator model.
|
||||||
|
:rtype: torch.nn.Module
|
||||||
|
"""
|
||||||
return self.models[0]
|
return self.models[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def discriminator(self):
|
def discriminator(self):
|
||||||
"""TODO"""
|
"""
|
||||||
|
The discriminator model.
|
||||||
|
|
||||||
|
:return: The discriminator model.
|
||||||
|
:rtype: torch.nn.Module
|
||||||
|
"""
|
||||||
return self.models[1]
|
return self.models[1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def optimizer_generator(self):
|
def optimizer_generator(self):
|
||||||
"""TODO"""
|
"""
|
||||||
|
The optimizer for the generator.
|
||||||
|
|
||||||
|
:return: The optimizer for the generator.
|
||||||
|
:rtype: Optimizer
|
||||||
|
"""
|
||||||
return self.optimizers[0].instance
|
return self.optimizers[0].instance
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def optimizer_discriminator(self):
|
def optimizer_discriminator(self):
|
||||||
"""TODO"""
|
"""
|
||||||
|
The optimizer for the discriminator.
|
||||||
|
|
||||||
|
:return: The optimizer for the discriminator.
|
||||||
|
:rtype: Optimizer
|
||||||
|
"""
|
||||||
return self.optimizers[1].instance
|
return self.optimizers[1].instance
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scheduler_generator(self):
|
def scheduler_generator(self):
|
||||||
"""TODO"""
|
"""
|
||||||
|
The scheduler for the generator.
|
||||||
|
|
||||||
|
:return: The scheduler for the generator.
|
||||||
|
:rtype: Scheduler
|
||||||
|
"""
|
||||||
return self.schedulers[0].instance
|
return self.schedulers[0].instance
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scheduler_discriminator(self):
|
def scheduler_discriminator(self):
|
||||||
"""TODO"""
|
"""
|
||||||
|
The scheduler for the discriminator.
|
||||||
|
|
||||||
|
:return: The scheduler for the discriminator.
|
||||||
|
:rtype: Scheduler
|
||||||
|
"""
|
||||||
return self.schedulers[1].instance
|
return self.schedulers[1].instance
|
||||||
|
|||||||
@@ -110,8 +110,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
def loss_data(self, input_pts, output_pts):
|
def loss_data(self, input_pts, output_pts):
|
||||||
"""
|
"""
|
||||||
Compute the data loss for the PINN solver by evaluating the loss
|
Compute the data loss for the PINN solver by evaluating the loss
|
||||||
between the network's output and the true solution. This method
|
between the network's output and the true solution. This method should
|
||||||
should only be overridden intentionally.
|
not be overridden, if not intentionally.
|
||||||
|
|
||||||
:param LabelTensor input_pts: The input points to the neural network.
|
:param LabelTensor input_pts: The input points to the neural network.
|
||||||
:param LabelTensor output_pts: The true solution to compare with the
|
:param LabelTensor output_pts: The true solution to compare with the
|
||||||
|
|||||||
@@ -1,19 +1,17 @@
|
|||||||
"""Module for ReducedOrderModelSolver"""
|
"""Module for ReducedOrderModelSolver"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from .supervised import SupervisedSolver
|
from .supervised import SupervisedSolver
|
||||||
|
|
||||||
|
|
||||||
class ReducedOrderModelSolver(SupervisedSolver):
|
class ReducedOrderModelSolver(SupervisedSolver):
|
||||||
r"""
|
r"""
|
||||||
ReducedOrderModelSolver solver class. This class implements a
|
Reduced Order Model solver class. This class implements the Reduced Order
|
||||||
Reduced Order Model solver, using user specified ``reduction_network`` and
|
Model solver, using user specified ``reduction_network`` and
|
||||||
``interpolation_network`` to solve a specific ``problem``.
|
``interpolation_network`` to solve a specific ``problem``.
|
||||||
|
|
||||||
The Reduced Order Model approach aims to find
|
The Reduced Order Model solver aims to find the solution
|
||||||
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
|
:math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential problem:
|
||||||
of the differential problem:
|
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
|
|
||||||
@@ -23,13 +21,13 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
|||||||
\mathbf{x}\in\partial\Omega
|
\mathbf{x}\in\partial\Omega
|
||||||
\end{cases}
|
\end{cases}
|
||||||
|
|
||||||
This is done by using two neural networks. The ``reduction_network``, which
|
This is done by means of two neural networks: the ``reduction_network``,
|
||||||
contains an encoder :math:`\mathcal{E}_{\rm{net}}`, a decoder
|
which defines an encoder :math:`\mathcal{E}_{\rm{net}}`, and a decoder
|
||||||
:math:`\mathcal{D}_{\rm{net}}`; and an ``interpolation_network``
|
:math:`\mathcal{D}_{\rm{net}}`; and the ``interpolation_network``
|
||||||
:math:`\mathcal{I}_{\rm{net}}`. The input is assumed to be discretised in
|
:math:`\mathcal{I}_{\rm{net}}`. The input is assumed to be discretised in
|
||||||
the spatial dimensions.
|
the spatial dimensions.
|
||||||
|
|
||||||
The following loss function is minimized during training
|
The following loss function is minimized during training:
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
|
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
|
||||||
@@ -39,49 +37,46 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
|||||||
\mathcal{D}_{\rm{net}}[\mathcal{E}_{\rm{net}}[\mathbf{u}(\mu_i)]] -
|
\mathcal{D}_{\rm{net}}[\mathcal{E}_{\rm{net}}[\mathbf{u}(\mu_i)]] -
|
||||||
\mathbf{u}(\mu_i))
|
\mathbf{u}(\mu_i))
|
||||||
|
|
||||||
where :math:`\mathcal{L}` is a specific loss function, default
|
where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
|
||||||
Mean Square Error:
|
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\mathcal{L}(v) = \| v \|^2_2.
|
\mathcal{L}(v) = \| v \|^2_2.
|
||||||
|
|
||||||
|
|
||||||
.. seealso::
|
.. seealso::
|
||||||
|
|
||||||
**Original reference**: Hesthaven, Jan S., and Stefano Ubbiali.
|
**Original reference**: Hesthaven, Jan S., and Stefano Ubbiali.
|
||||||
"Non-intrusive reduced order modeling of nonlinear problems
|
"Non-intrusive reduced order modeling of nonlinear problems using
|
||||||
using neural networks." Journal of Computational
|
neural networks."
|
||||||
Physics 363 (2018): 55-78.
|
Journal of Computational Physics 363 (2018): 55-78.
|
||||||
DOI `10.1016/j.jcp.2018.02.037
|
DOI `10.1016/j.jcp.2018.02.037
|
||||||
<https://doi.org/10.1016/j.jcp.2018.02.037>`_.
|
<https://doi.org/10.1016/j.jcp.2018.02.037>`_.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
The specified ``reduction_network`` must contain two methods,
|
The specified ``reduction_network`` must contain two methods, namely
|
||||||
namely ``encode`` for input encoding and ``decode`` for decoding the
|
``encode`` for input encoding, and ``decode`` for decoding the former
|
||||||
former result. The ``interpolation_network`` network ``forward`` output
|
result. The ``interpolation_network`` network ``forward`` output
|
||||||
represents the interpolation of the latent space obtain with
|
represents the interpolation of the latent space obtained with
|
||||||
``reduction_network.encode``.
|
``reduction_network.encode``.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
This solver uses the end-to-end training strategy, i.e. the
|
This solver uses the end-to-end training strategy, i.e. the
|
||||||
``reduction_network`` and ``interpolation_network`` are trained
|
``reduction_network`` and ``interpolation_network`` are trained
|
||||||
simultaneously. For reference on this trainig strategy look at:
|
simultaneously. For reference on this trainig strategy look at the
|
||||||
Pichi, Federico, Beatriz Moya, and Jan S. Hesthaven.
|
following:
|
||||||
|
|
||||||
|
..seealso::
|
||||||
|
**Original reference**: Pichi, Federico, Beatriz Moya, and Jan S.
|
||||||
|
Hesthaven.
|
||||||
"A graph convolutional autoencoder approach to model order reduction
|
"A graph convolutional autoencoder approach to model order reduction
|
||||||
for parametrized PDEs." Journal of
|
for parametrized PDEs."
|
||||||
Computational Physics 501 (2024): 112762.
|
Journal of Computational Physics 501 (2024): 112762.
|
||||||
DOI
|
DOI `10.1016/j.jcp.2024.112762
|
||||||
`10.1016/j.jcp.2024.112762 <https://doi.org/10.1016/
|
<https://doi.org/10.1016/j.jcp.2024.112762>`_.
|
||||||
j.jcp.2024.112762>`_.
|
|
||||||
|
|
||||||
.. warning::
|
.. warning::
|
||||||
This solver works only for data-driven model. Hence in the ``problem``
|
This solver works only for data-driven model. Hence in the ``problem``
|
||||||
definition the codition must only contain ``input``
|
definition the codition must only contain ``input``
|
||||||
(e.g. coefficient parameters, time parameters), and ``target``.
|
(e.g. coefficient parameters, time parameters), and ``target``.
|
||||||
|
|
||||||
.. warning::
|
|
||||||
This solver does not currently support the possibility to pass
|
|
||||||
``extra_feature``.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -96,22 +91,28 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
|||||||
use_lt=True,
|
use_lt=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
Initialization of the :class:`ReducedOrderModelSolver` class.
|
||||||
|
|
||||||
:param AbstractProblem problem: The formualation of the problem.
|
:param AbstractProblem problem: The formualation of the problem.
|
||||||
:param torch.nn.Module reduction_network: The reduction network used
|
:param torch.nn.Module reduction_network: The reduction network used
|
||||||
for reducing the input space. It must contain two methods,
|
for reducing the input space. It must contain two methods, namely
|
||||||
namely ``encode`` for input encoding and ``decode`` for decoding the
|
``encode`` for input encoding, and ``decode`` for decoding the
|
||||||
former result.
|
former result.
|
||||||
:param torch.nn.Module interpolation_network: The interpolation network
|
:param torch.nn.Module interpolation_network: The interpolation network
|
||||||
for interpolating the control parameters to latent space obtain by
|
for interpolating the control parameters to latent space obtained by
|
||||||
the ``reduction_network`` encoding.
|
the ``reduction_network`` encoding.
|
||||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
:param torch.nn.Module loss: The loss function to be minimized.
|
||||||
default :class:`torch.nn.MSELoss`.
|
If `None`, the :class:`torch.nn.MSELoss` loss is used.
|
||||||
:param torch.optim.Optimizer optimizer: The neural network optimizer to
|
Default is `None`.
|
||||||
use; default is :class:`torch.optim.Adam`.
|
:param Optimizer optimizer: The optimizer to be used.
|
||||||
:param torch.optim.LRScheduler scheduler: Learning
|
If `None`, the :class:`torch.optim.Adam`. optimizer is used.
|
||||||
rate scheduler.
|
Default is ``None``.
|
||||||
:param WeightingInterface weighting: The loss weighting to use.
|
:param Scheduler scheduler: Learning rate scheduler. If `None`,
|
||||||
:param bool use_lt: Using LabelTensors as input during training.
|
the constant learning rate scheduler is used. Default is ``None``.
|
||||||
|
:param WeightingInterface weighting: The weighting schema to be used.
|
||||||
|
If `None`, no weighting schema is used. Default is ``None``.
|
||||||
|
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||||
|
Default is ``True``.
|
||||||
"""
|
"""
|
||||||
model = torch.nn.ModuleDict(
|
model = torch.nn.ModuleDict(
|
||||||
{
|
{
|
||||||
@@ -146,10 +147,10 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Forward pass implementation for the solver. It finds the encoder
|
Forward pass implementation.
|
||||||
representation by calling ``interpolation_network.forward`` on the
|
It computes the encoder representation by calling the forward method
|
||||||
input, and maps this representation to output space by calling
|
of the ``interpolation_network`` on the input, and maps it to output
|
||||||
``reduction_network.decode``.
|
space by calling the decode methode of the ``reduction_network``.
|
||||||
|
|
||||||
:param torch.Tensor x: Input tensor.
|
:param torch.Tensor x: Input tensor.
|
||||||
:return: Solver solution.
|
:return: Solver solution.
|
||||||
@@ -161,15 +162,14 @@ class ReducedOrderModelSolver(SupervisedSolver):
|
|||||||
|
|
||||||
def loss_data(self, input_pts, output_pts):
|
def loss_data(self, input_pts, output_pts):
|
||||||
"""
|
"""
|
||||||
The data loss for the ReducedOrderModelSolver solver.
|
Compute the data loss by evaluating the loss between the network's
|
||||||
It computes the loss between
|
output and the true solution. This method should not be overridden, if
|
||||||
the network output against the true solution. This function
|
not intentionally.
|
||||||
should not be override if not intentionally.
|
|
||||||
|
|
||||||
:param LabelTensor input_tensor: The input to the neural networks.
|
:param LabelTensor input_pts: The input points to the neural network.
|
||||||
:param LabelTensor output_tensor: The true solution to compare the
|
:param LabelTensor output_pts: The true solution to compare with the
|
||||||
network solution.
|
network's output.
|
||||||
:return: The residual loss averaged on the input coordinates
|
:return: The supervised loss, averaged over the number of observations.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
# extract networks
|
# extract networks
|
||||||
|
|||||||
@@ -14,17 +14,19 @@ from ..utils import check_consistency, labelize_forward
|
|||||||
|
|
||||||
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
SolverInterface base class. This class is a wrapper of LightningModule.
|
Abstract base class for PINA solvers. All specific solvers should inherit
|
||||||
|
from this interface. This class is a wrapper of
|
||||||
|
:class:`~lightning.pytorch.LightningModule`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, problem, weighting, use_lt):
|
def __init__(self, problem, weighting, use_lt):
|
||||||
"""
|
"""
|
||||||
:param problem: A problem definition instance.
|
Initialization of the :class:`SolverInterface` class.
|
||||||
:type problem: AbstractProblem
|
|
||||||
:param weighting: The loss weighting to use.
|
:param AbstractProblem problem: The problem to be solved.
|
||||||
:type weighting: WeightingInterface
|
:param WeightingInterface weighting: The weighting schema to be used.
|
||||||
:param use_lt: Using LabelTensors as input during training.
|
If `None`, no weighting schema is used. Default is ``None``.
|
||||||
:type use_lt: bool
|
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -59,22 +61,24 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
self._pina_schedulers = None
|
self._pina_schedulers = None
|
||||||
|
|
||||||
def _check_solver_consistency(self, problem):
|
def _check_solver_consistency(self, problem):
|
||||||
|
"""
|
||||||
|
Check the consistency of the solver with the problem formulation.
|
||||||
|
|
||||||
|
:param AbstractProblem problem: The problem to be solved.
|
||||||
|
"""
|
||||||
for condition in problem.conditions.values():
|
for condition in problem.conditions.values():
|
||||||
check_consistency(condition, self.accepted_conditions_types)
|
check_consistency(condition, self.accepted_conditions_types)
|
||||||
|
|
||||||
def _optimization_cycle(self, batch):
|
def _optimization_cycle(self, batch):
|
||||||
"""
|
"""
|
||||||
Perform a private optimization cycle by computing the loss for each
|
Aggregate the loss for each condition in the batch.
|
||||||
condition in the given batch. The loss are later aggregated using the
|
|
||||||
specific weighting schema.
|
|
||||||
|
|
||||||
:param batch: A batch of data, where each element is a tuple containing
|
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
|
||||||
a condition name and a dictionary of points.
|
tuple containing a condition name and a dictionary of points.
|
||||||
:type batch: list of tuples (str, dict)
|
:return: The computed loss for the all conditions in the batch, casted
|
||||||
:return: The computed loss for the all conditions in the batch,
|
to a subclass of `torch.Tensor`. It should return a dict containing
|
||||||
cast to a subclass of `torch.Tensor`. It should return a dict
|
the condition name and the associated scalar loss.
|
||||||
containing the condition name and the associated scalar loss.
|
:rtype: dict
|
||||||
:rtype: dict(torch.Tensor)
|
|
||||||
"""
|
"""
|
||||||
losses = self.optimization_cycle(batch)
|
losses = self.optimization_cycle(batch)
|
||||||
for name, value in losses.items():
|
for name, value in losses.items():
|
||||||
@@ -88,9 +92,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
Solver training step.
|
Solver training step.
|
||||||
|
|
||||||
:param batch: The batch element in the dataloader.
|
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||||
:type batch: tuple
|
:return: The loss of the training step.
|
||||||
:return: The sum of the loss functions.
|
|
||||||
:rtype: LabelTensor
|
:rtype: LabelTensor
|
||||||
"""
|
"""
|
||||||
loss = self._optimization_cycle(batch=batch)
|
loss = self._optimization_cycle(batch=batch)
|
||||||
@@ -101,8 +104,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
Solver validation step.
|
Solver validation step.
|
||||||
|
|
||||||
:param batch: The batch element in the dataloader.
|
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||||
:type batch: tuple
|
|
||||||
"""
|
"""
|
||||||
loss = self._optimization_cycle(batch=batch)
|
loss = self._optimization_cycle(batch=batch)
|
||||||
self.store_log("val_loss", loss, self.get_batch_size(batch))
|
self.store_log("val_loss", loss, self.get_batch_size(batch))
|
||||||
@@ -111,15 +113,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
Solver test step.
|
Solver test step.
|
||||||
|
|
||||||
:param batch: The batch element in the dataloader.
|
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||||
:type batch: tuple
|
|
||||||
"""
|
"""
|
||||||
loss = self._optimization_cycle(batch=batch)
|
loss = self._optimization_cycle(batch=batch)
|
||||||
self.store_log("test_loss", loss, self.get_batch_size(batch))
|
self.store_log("test_loss", loss, self.get_batch_size(batch))
|
||||||
|
|
||||||
def store_log(self, name, value, batch_size):
|
def store_log(self, name, value, batch_size):
|
||||||
"""
|
"""
|
||||||
TODO
|
Store the log of the solver.
|
||||||
|
|
||||||
|
:param str name: The name of the log.
|
||||||
|
:param torch.Tensor value: The value of the log.
|
||||||
|
:param int batch_size: The size of the batch.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.log(
|
self.log(
|
||||||
@@ -132,49 +137,59 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
TODO
|
Abstract method for the forward pass implementation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def optimization_cycle(self, batch):
|
def optimization_cycle(self, batch):
|
||||||
"""
|
"""
|
||||||
Perform an optimization cycle by computing the loss for each condition
|
The optimization cycle for the solvers.
|
||||||
in the given batch.
|
|
||||||
|
|
||||||
:param batch: A batch of data, where each element is a tuple containing
|
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||||
a condition name and a dictionary of points.
|
:return: The computed loss for the all conditions in the batch, casted
|
||||||
:type batch: list of tuples (str, dict)
|
to a subclass of `torch.Tensor`. It should return a dict containing
|
||||||
:return: The computed loss for the all conditions in the batch,
|
the condition name and the associated scalar loss.
|
||||||
cast to a subclass of `torch.Tensor`. It should return a dict
|
:rtype: dict
|
||||||
containing the condition name and the associated scalar loss.
|
|
||||||
:rtype: dict(torch.Tensor)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def problem(self):
|
def problem(self):
|
||||||
"""
|
"""
|
||||||
The problem formulation.
|
The problem instance.
|
||||||
|
|
||||||
|
:return: The problem instance.
|
||||||
|
:rtype: :class:`~pina.problem.abstract_problem.AbstractProblem`
|
||||||
"""
|
"""
|
||||||
return self._pina_problem
|
return self._pina_problem
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def use_lt(self):
|
def use_lt(self):
|
||||||
"""
|
"""
|
||||||
Using LabelTensor in training.
|
Using LabelTensors as input during training.
|
||||||
|
|
||||||
|
:return: The use_lt attribute.
|
||||||
|
:rtype: bool
|
||||||
"""
|
"""
|
||||||
return self._use_lt
|
return self._use_lt
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def weighting(self):
|
def weighting(self):
|
||||||
"""
|
"""
|
||||||
The weighting mechanism.
|
The weighting schema.
|
||||||
|
|
||||||
|
:return: The weighting schema.
|
||||||
|
:rtype: :class:`~pina.loss.weighting_interface.WeightingInterface`
|
||||||
"""
|
"""
|
||||||
return self._pina_weighting
|
return self._pina_weighting
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_batch_size(batch):
|
def get_batch_size(batch):
|
||||||
"""
|
"""
|
||||||
TODO
|
Get the batch size.
|
||||||
|
|
||||||
|
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||||
|
:return: The size of the batch.
|
||||||
|
:rtype: int
|
||||||
"""
|
"""
|
||||||
|
|
||||||
batch_size = 0
|
batch_size = 0
|
||||||
@@ -185,23 +200,29 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def default_torch_optimizer():
|
def default_torch_optimizer():
|
||||||
"""
|
"""
|
||||||
TODO
|
Set the default optimizer to :class:`torch.optim.Adam`.
|
||||||
"""
|
|
||||||
|
|
||||||
|
:return: The default optimizer.
|
||||||
|
:rtype: Optimizer
|
||||||
|
"""
|
||||||
return TorchOptimizer(torch.optim.Adam, lr=0.001)
|
return TorchOptimizer(torch.optim.Adam, lr=0.001)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def default_torch_scheduler():
|
def default_torch_scheduler():
|
||||||
"""
|
"""
|
||||||
TODO
|
Set the default scheduler to
|
||||||
|
:class:`torch.optim.lr_scheduler.ConstantLR`.
|
||||||
|
|
||||||
|
:return: The default scheduler.
|
||||||
|
:rtype: Scheduler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
|
return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
|
||||||
|
|
||||||
def on_train_start(self):
|
def on_train_start(self):
|
||||||
"""
|
"""
|
||||||
Hook that is called before training begins.
|
This method is called at the start of the training process to compile
|
||||||
Used to compile the model if the trainer is set to compile.
|
the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``.
|
||||||
"""
|
"""
|
||||||
super().on_train_start()
|
super().on_train_start()
|
||||||
if self.trainer.compile:
|
if self.trainer.compile:
|
||||||
@@ -209,8 +230,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
|
|
||||||
def on_test_start(self):
|
def on_test_start(self):
|
||||||
"""
|
"""
|
||||||
Hook that is called before training begins.
|
This method is called at the start of the test process to compile
|
||||||
Used to compile the model if the trainer is set to compile.
|
the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``.
|
||||||
"""
|
"""
|
||||||
super().on_train_start()
|
super().on_train_start()
|
||||||
if self.trainer.compile and not self._check_already_compiled():
|
if self.trainer.compile and not self._check_already_compiled():
|
||||||
@@ -218,7 +239,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
|
|
||||||
def _check_already_compiled(self):
|
def _check_already_compiled(self):
|
||||||
"""
|
"""
|
||||||
TODO
|
Check if the model is already compiled.
|
||||||
|
|
||||||
|
:return: ``True`` if the model is already compiled, ``False`` otherwise.
|
||||||
|
:rtype: bool
|
||||||
"""
|
"""
|
||||||
|
|
||||||
models = self._pina_models
|
models = self._pina_models
|
||||||
@@ -234,7 +258,12 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _perform_compilation(model):
|
def _perform_compilation(model):
|
||||||
"""
|
"""
|
||||||
TODO
|
Perform the compilation of the model.
|
||||||
|
|
||||||
|
:param torch.nn.Module model: The model to compile.
|
||||||
|
:raises Exception: If the compilation fails.
|
||||||
|
:return: The compiled model.
|
||||||
|
:rtype: torch.nn.Module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_device = next(model.parameters()).device
|
model_device = next(model.parameters()).device
|
||||||
@@ -249,8 +278,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
|
|||||||
|
|
||||||
|
|
||||||
class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||||
"""TODO"""
|
"""
|
||||||
|
Base class for PINA solvers using a single :class:`torch.nn.Module`.
|
||||||
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
problem,
|
problem,
|
||||||
@@ -261,14 +291,18 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
use_lt=True,
|
use_lt=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param problem: A problem definition instance.
|
Initialization of the :class:`SingleSolverInterface` class.
|
||||||
:type problem: AbstractProblem
|
|
||||||
:param model: A torch nn.Module instances.
|
:param AbstractProblem problem: The problem to be solved.
|
||||||
:type model: torch.nn.Module
|
:param torch.nn.Module model: The neural network model to be used.
|
||||||
:param Optimizer optimizers: A neural network optimizers to use.
|
:param Optimizer optimizer: The optimizer to be used.
|
||||||
:param Scheduler optimizers: A neural network scheduler to use.
|
If `None`, the Adam optimizer is used. Default is ``None``.
|
||||||
:param WeightingInterface weighting: The loss weighting to use.
|
:param Scheduler scheduler: The scheduler to be used.
|
||||||
:param bool use_lt: Using LabelTensors as input during training.
|
If `None`, the constant learning rate scheduler is used.
|
||||||
|
Default is ``None``.
|
||||||
|
:param WeightingInterface weighting: The weighting schema to be used.
|
||||||
|
If `None`, no weighting schema is used. Default is ``None``.
|
||||||
|
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||||
"""
|
"""
|
||||||
if optimizer is None:
|
if optimizer is None:
|
||||||
optimizer = self.default_torch_optimizer()
|
optimizer = self.default_torch_optimizer()
|
||||||
@@ -292,11 +326,12 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
"""
|
"""
|
||||||
Forward pass implementation for the solver.
|
Forward pass implementation.
|
||||||
|
|
||||||
:param torch.Tensor x: Input tensor.
|
:param x: Input tensor.
|
||||||
|
:type x: torch.Tensor | LabelTensor
|
||||||
:return: Solver solution.
|
:return: Solver solution.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor | LabelTensor
|
||||||
"""
|
"""
|
||||||
x = self.model(x)
|
x = self.model(x)
|
||||||
return x
|
return x
|
||||||
@@ -305,7 +340,7 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
Optimizer configuration for the solver.
|
Optimizer configuration for the solver.
|
||||||
|
|
||||||
:return: The optimizers and the schedulers
|
:return: The optimizer and the scheduler
|
||||||
:rtype: tuple(list, list)
|
:rtype: tuple(list, list)
|
||||||
"""
|
"""
|
||||||
self.optimizer.hook(self.model.parameters())
|
self.optimizer.hook(self.model.parameters())
|
||||||
@@ -313,44 +348,61 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
return ([self.optimizer.instance], [self.scheduler.instance])
|
return ([self.optimizer.instance], [self.scheduler.instance])
|
||||||
|
|
||||||
def _compile_model(self):
|
def _compile_model(self):
|
||||||
|
"""
|
||||||
|
Compile the model.
|
||||||
|
"""
|
||||||
if isinstance(self._pina_models[0], torch.nn.ModuleDict):
|
if isinstance(self._pina_models[0], torch.nn.ModuleDict):
|
||||||
self._compile_module_dict()
|
self._compile_module_dict()
|
||||||
else:
|
else:
|
||||||
self._compile_single_model()
|
self._compile_single_model()
|
||||||
|
|
||||||
def _compile_module_dict(self):
|
def _compile_module_dict(self):
|
||||||
|
"""
|
||||||
|
Compile the model if it is a :class:`torch.nn.ModuleDict`.
|
||||||
|
"""
|
||||||
for name, model in self._pina_models[0].items():
|
for name, model in self._pina_models[0].items():
|
||||||
self._pina_models[0][name] = self._perform_compilation(model)
|
self._pina_models[0][name] = self._perform_compilation(model)
|
||||||
|
|
||||||
def _compile_single_model(self):
|
def _compile_single_model(self):
|
||||||
|
"""
|
||||||
|
Compile the model if it is a single :class:`torch.nn.Module`.
|
||||||
|
"""
|
||||||
self._pina_models[0] = self._perform_compilation(self._pina_models[0])
|
self._pina_models[0] = self._perform_compilation(self._pina_models[0])
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self):
|
def model(self):
|
||||||
"""
|
"""
|
||||||
Model for training.
|
The model used for training.
|
||||||
|
|
||||||
|
:return: The model used for training.
|
||||||
|
:rtype: torch.nn.Module
|
||||||
"""
|
"""
|
||||||
return self._pina_models[0]
|
return self._pina_models[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scheduler(self):
|
def scheduler(self):
|
||||||
"""
|
"""
|
||||||
Scheduler for training.
|
The scheduler used for training.
|
||||||
|
|
||||||
|
:return: The scheduler used for training.
|
||||||
|
:rtype: Scheduler
|
||||||
"""
|
"""
|
||||||
return self._pina_schedulers[0]
|
return self._pina_schedulers[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def optimizer(self):
|
def optimizer(self):
|
||||||
"""
|
"""
|
||||||
Optimizer for training.
|
The optimizer used for training.
|
||||||
|
|
||||||
|
:return: The optimizer used for training.
|
||||||
|
:rtype: Optimizer
|
||||||
"""
|
"""
|
||||||
return self._pina_optimizers[0]
|
return self._pina_optimizers[0]
|
||||||
|
|
||||||
|
|
||||||
class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
Multiple Solver base class. This class inherits is a wrapper of
|
Base class for PINA solvers using multiple :class:`torch.nn.Module`.
|
||||||
SolverInterface class
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -363,16 +415,22 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
use_lt=True,
|
use_lt=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param problem: A problem definition instance.
|
Initialization of the :class:`MultiSolverInterface` class.
|
||||||
:type problem: AbstractProblem
|
|
||||||
:param models: Multiple torch nn.Module instances.
|
:param AbstractProblem problem: The problem to be solved.
|
||||||
|
:param models: The neural network models to be used.
|
||||||
:type model: list[torch.nn.Module] | tuple[torch.nn.Module]
|
:type model: list[torch.nn.Module] | tuple[torch.nn.Module]
|
||||||
:param list(Optimizer) optimizers: A list of neural network
|
:param list[Optimizer] optimizers: The optimizers to be used.
|
||||||
optimizers to use.
|
If `None`, the Adam optimizer is used for all models.
|
||||||
:param list(Scheduler) optimizers: A list of neural network
|
Default is ``None``.
|
||||||
schedulers to use.
|
:param list[Scheduler] schedulers: The schedulers to be used.
|
||||||
:param WeightingInterface weighting: The loss weighting to use.
|
If `None`, the constant learning rate scheduler is used for all the
|
||||||
:param bool use_lt: Using LabelTensors as input during training.
|
models. Default is ``None``.
|
||||||
|
:param WeightingInterface weighting: The weighting schema to be used.
|
||||||
|
If `None`, no weighting schema is used. Default is ``None``.
|
||||||
|
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||||
|
:raises ValueError: If the models are not a list or tuple with length
|
||||||
|
greater than one.
|
||||||
"""
|
"""
|
||||||
if not isinstance(models, (list, tuple)) or len(models) < 2:
|
if not isinstance(models, (list, tuple)) or len(models) < 2:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -418,9 +476,10 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
self._pina_schedulers = schedulers
|
self._pina_schedulers = schedulers
|
||||||
|
|
||||||
def configure_optimizers(self):
|
def configure_optimizers(self):
|
||||||
"""Optimizer configuration for the solver.
|
"""
|
||||||
|
Optimizer configuration for the solver.
|
||||||
|
|
||||||
:return: The optimizers and the schedulers
|
:return: The optimizer and the scheduler
|
||||||
:rtype: tuple(list, list)
|
:rtype: tuple(list, list)
|
||||||
"""
|
"""
|
||||||
for optimizer, scheduler, model in zip(
|
for optimizer, scheduler, model in zip(
|
||||||
@@ -435,6 +494,9 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _compile_model(self):
|
def _compile_model(self):
|
||||||
|
"""
|
||||||
|
Compile the model.
|
||||||
|
"""
|
||||||
for i, model in enumerate(self._pina_models):
|
for i, model in enumerate(self._pina_models):
|
||||||
if not isinstance(model, torch.nn.ModuleDict):
|
if not isinstance(model, torch.nn.ModuleDict):
|
||||||
self._pina_models[i] = self._perform_compilation(model)
|
self._pina_models[i] = self._perform_compilation(model)
|
||||||
@@ -442,17 +504,29 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
|
|||||||
@property
|
@property
|
||||||
def models(self):
|
def models(self):
|
||||||
"""
|
"""
|
||||||
The torch model."""
|
The models used for training.
|
||||||
|
|
||||||
|
:return: The models used for training.
|
||||||
|
:rtype: torch.nn.ModuleList
|
||||||
|
"""
|
||||||
return self._pina_models
|
return self._pina_models
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def optimizers(self):
|
def optimizers(self):
|
||||||
"""
|
"""
|
||||||
The torch model."""
|
The optimizers used for training.
|
||||||
|
|
||||||
|
:return: The optimizers used for training.
|
||||||
|
:rtype: list[Optimizer]
|
||||||
|
"""
|
||||||
return self._pina_optimizers
|
return self._pina_optimizers
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def schedulers(self):
|
def schedulers(self):
|
||||||
"""
|
"""
|
||||||
The torch model."""
|
The schedulers used for training.
|
||||||
|
|
||||||
|
:return: The schedulers used for training.
|
||||||
|
:rtype: list[Scheduler]
|
||||||
|
"""
|
||||||
return self._pina_schedulers
|
return self._pina_schedulers
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Module for SupervisedSolver"""
|
"""Module for the Supervised Solver."""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.modules.loss import _Loss
|
from torch.nn.modules.loss import _Loss
|
||||||
@@ -10,31 +10,28 @@ from ..condition import InputTargetCondition
|
|||||||
|
|
||||||
class SupervisedSolver(SingleSolverInterface):
|
class SupervisedSolver(SingleSolverInterface):
|
||||||
r"""
|
r"""
|
||||||
SupervisedSolver solver class. This class implements a SupervisedSolver,
|
Supervised Solver solver class. This class implements a Supervised Solver,
|
||||||
using a user specified ``model`` to solve a specific ``problem``.
|
using a user specified ``model`` to solve a specific ``problem``.
|
||||||
|
|
||||||
The Supervised Solver class aims to find
|
The Supervised Solver class aims to find a map between the input
|
||||||
a map between the input :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m`
|
:math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m` and the output
|
||||||
and the output :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. The input
|
:math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`.
|
||||||
can be discretised in space (as in :obj:`~pina.solver.rom.ROMe2eSolver`),
|
|
||||||
or not (e.g. when training Neural Operators).
|
|
||||||
|
|
||||||
Given a model :math:`\mathcal{M}`, the following loss function is
|
Given a model :math:`\mathcal{M}`, the following loss function is
|
||||||
minimized during training:
|
minimized during training:
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
|
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
|
||||||
\mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i))
|
\mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i)),
|
||||||
|
|
||||||
where :math:`\mathcal{L}` is a specific loss function,
|
where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
|
||||||
default Mean Square Error:
|
|
||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
\mathcal{L}(v) = \| v \|^2_2.
|
\mathcal{L}(v) = \| v \|^2_2.
|
||||||
|
|
||||||
In this context :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` means that
|
In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` indicates
|
||||||
we are seeking to approximate multiple (discretised) functions given
|
the will to approximate multiple (discretised) functions given multiple
|
||||||
multiple (discretised) input functions.
|
(discretised) input functions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
accepted_conditions_types = InputTargetCondition
|
accepted_conditions_types = InputTargetCondition
|
||||||
@@ -50,16 +47,22 @@ class SupervisedSolver(SingleSolverInterface):
|
|||||||
use_lt=True,
|
use_lt=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
:param AbstractProblem problem: The formualation of the problem.
|
Initialization of the :class:`SupervisedSolver` class.
|
||||||
:param torch.nn.Module model: The neural network model to use.
|
|
||||||
:param torch.nn.Module loss: The loss function used as minimizer,
|
:param AbstractProblem problem: The problem to be solved.
|
||||||
default :class:`torch.nn.MSELoss`.
|
:param torch.nn.Module model: The neural network model to be used.
|
||||||
:param torch.optim.Optimizer optimizer: The neural network optimizer to
|
:param torch.nn.Module loss: The loss function to be minimized.
|
||||||
use; default is :class:`torch.optim.Adam`.
|
If `None`, the Mean Squared Error (MSE) loss is used.
|
||||||
:param torch.optim.LRScheduler scheduler: Learning
|
Default is `None`.
|
||||||
rate scheduler.
|
:param torch.optim.Optimizer optimizer: The optimizer to be used.
|
||||||
:param WeightingInterface weighting: The loss weighting to use.
|
If `None`, the Adam optimizer is used. Default is ``None``.
|
||||||
:param bool use_lt: Using LabelTensors as input during training.
|
:param torch.optim.LRScheduler scheduler: Learning rate scheduler.
|
||||||
|
If `None`, the constant learning rate scheduler is used.
|
||||||
|
Default is ``None``.
|
||||||
|
:param WeightingInterface weighting: The weighting schema to be used.
|
||||||
|
If `None`, no weighting schema is used. Default is ``None``.
|
||||||
|
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
|
||||||
|
Default is ``True``.
|
||||||
"""
|
"""
|
||||||
if loss is None:
|
if loss is None:
|
||||||
loss = torch.nn.MSELoss()
|
loss = torch.nn.MSELoss()
|
||||||
@@ -81,16 +84,13 @@ class SupervisedSolver(SingleSolverInterface):
|
|||||||
|
|
||||||
def optimization_cycle(self, batch):
|
def optimization_cycle(self, batch):
|
||||||
"""
|
"""
|
||||||
Perform an optimization cycle by computing the loss for each condition
|
The optimization cycle for the solvers.
|
||||||
in the given batch.
|
|
||||||
|
|
||||||
:param batch: A batch of data, where each element is a tuple containing
|
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
|
||||||
a condition name and a dictionary of points.
|
:return: The computed loss for the all conditions in the batch, casted
|
||||||
:type batch: list of tuples (str, dict)
|
to a subclass of `torch.Tensor`. It should return a dict containing
|
||||||
:return: The computed loss for the all conditions in the batch,
|
the condition name and the associated scalar loss.
|
||||||
cast to a subclass of `torch.Tensor`. It should return a dict
|
:rtype: dict
|
||||||
containing the condition name and the associated scalar loss.
|
|
||||||
:rtype: dict(torch.Tensor)
|
|
||||||
"""
|
"""
|
||||||
condition_loss = {}
|
condition_loss = {}
|
||||||
for condition_name, points in batch:
|
for condition_name, points in batch:
|
||||||
@@ -105,16 +105,16 @@ class SupervisedSolver(SingleSolverInterface):
|
|||||||
|
|
||||||
def loss_data(self, input_pts, output_pts):
|
def loss_data(self, input_pts, output_pts):
|
||||||
"""
|
"""
|
||||||
The data loss for the Supervised solver. It computes the loss between
|
Compute the data loss for the Supervised solver by evaluating the loss
|
||||||
the network output against the true solution. This function
|
between the network's output and the true solution. This method should
|
||||||
should not be override if not intentionally.
|
not be overridden, if not intentionally.
|
||||||
|
|
||||||
:param input_pts: The input to the neural networks.
|
:param input_pts: The input points to the neural network.
|
||||||
:type input_pts: LabelTensor | torch.Tensor
|
:type input_pts: LabelTensor | torch.Tensor
|
||||||
:param output_pts: The true solution to compare the
|
:param output_pts: The true solution to compare with the network's
|
||||||
network solution.
|
output.
|
||||||
:type output_pts: LabelTensor | torch.Tensor
|
:type output_pts: LabelTensor | torch.Tensor
|
||||||
:return: The residual loss.
|
:return: The supervised loss, averaged over the number of observations.
|
||||||
:rtype: torch.Tensor
|
:rtype: torch.Tensor
|
||||||
"""
|
"""
|
||||||
return self._loss(self.forward(input_pts), output_pts)
|
return self._loss(self.forward(input_pts), output_pts)
|
||||||
@@ -122,6 +122,9 @@ class SupervisedSolver(SingleSolverInterface):
|
|||||||
@property
|
@property
|
||||||
def loss(self):
|
def loss(self):
|
||||||
"""
|
"""
|
||||||
Loss for training.
|
The loss function to be minimized.
|
||||||
|
|
||||||
|
:return: The loss function to be minimized.
|
||||||
|
:rtype: torch.nn.Module
|
||||||
"""
|
"""
|
||||||
return self._loss
|
return self._loss
|
||||||
|
|||||||
Reference in New Issue
Block a user