fix doc solver

giovanni
2025-03-13 19:05:15 +01:00
committed by Nicola Demo
parent 3f8665b5d8
commit 5d908a291d
6 changed files with 374 additions and 251 deletions

View File

@@ -1,6 +1,4 @@
"""
TODO
"""
"""Module for the solver classes."""
__all__ = [
"SolverInterface",

View File

@@ -1,4 +1,4 @@
"""Module for GAROM"""
"""Module for the GAROM solver."""
import torch
from torch.nn.modules.loss import _Loss
@@ -10,9 +10,9 @@ from ..loss import LossInterface, PowerLoss
class GAROM(MultiSolverInterface):
"""
GAROM solver class. This class implements Generative Adversarial
Reduced Order Model solver, using user specified ``models`` to solve
a specific order reduction``problem``.
GAROM solver class. This class implements the Generative Adversarial
Reduced Order Model solver, using user-specified ``models`` to solve a
specific order reduction ``problem``.
.. seealso::
@@ -39,40 +39,28 @@ class GAROM(MultiSolverInterface):
regularizer=False,
):
"""
:param AbstractProblem problem: The formualation of the problem.
:param torch.nn.Module generator: The neural network model to use
for the generator.
:param torch.nn.Module discriminator: The neural network model to use
Initialization of the :class:`GAROM` class.
:param AbstractProblem problem: The formulation of the problem.
:param torch.nn.Module generator: The generator model.
:param torch.nn.Module discriminator: The discriminator model.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, ``PowerLoss(p=1)`` is used. Default is ``None``.
:param Optimizer optimizer_generator: The optimizer for the generator.
If ``None``, the Adam optimizer is used. Default is ``None``.
:param Optimizer optimizer_discriminator: The optimizer for the
discriminator. If ``None``, the Adam optimizer is used.
Default is ``None``.
:param Scheduler scheduler_generator: The learning rate scheduler for
the generator.
:param Scheduler scheduler_discriminator: The learning rate scheduler
for the discriminator.
:param torch.nn.Module loss: The loss function used as minimizer,
default ``None``. If ``loss`` is ``None`` the defualt
``PowerLoss(p=1)`` is used, as in the original paper.
:param Optimizer optimizer_generator: The neural
network optimizer to use for the generator network
, default is `torch.optim.Adam`.
:param Optimizer optimizer_discriminator: The neural
network optimizer to use for the discriminator network
, default is `torch.optim.Adam`.
:param Scheduler scheduler_generator: Learning
rate scheduler for the generator.
:param Scheduler scheduler_discriminator: Learning
rate scheduler for the discriminator.
:param dict scheduler_discriminator_kwargs: LR scheduler constructor
keyword args.
:param gamma: Ratio of expected loss for generator and discriminator,
defaults to 0.3.
:type gamma: float
:param lambda_k: Learning rate for control theory optimization,
defaults to 0.001.
:type lambda_k: float
:param regularizer: Regularization term in the GAROM loss,
defaults to False.
:type regularizer: bool
.. warning::
The algorithm works only for data-driven model. Hence in the
``problem`` definition the codition must only contain ``input``
(e.g. coefficient parameters, time parameters), and ``target``.
:param float gamma: Ratio of expected loss for generator and
discriminator. Default is ``0.3``.
:param float lambda_k: Learning rate for control theory optimization.
Default is ``0.001``.
:param bool regularizer: If ``True``, uses a regularization term in the
GAROM loss. Default is ``False``.
"""
# set loss
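Taken together, the parameters above imply a construction along these
lines. A minimal sketch, not part of this diff: the import path is an
assumption, ``problem`` stands for a data-driven AbstractProblem defined
elsewhere, and both architectures are purely illustrative.

import torch
from pina.solver import GAROM  # assumed import path

# Illustrative architectures: any pair of torch.nn.Module works here.
generator = torch.nn.Sequential(
    torch.nn.Linear(2, 64), torch.nn.ReLU(), torch.nn.Linear(64, 100)
)
discriminator = torch.nn.Sequential(
    torch.nn.Linear(100, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
)

solver = GAROM(
    problem=problem,      # data-driven problem with input/target conditions
    generator=generator,
    discriminator=discriminator,
    loss=None,            # falls back to PowerLoss(p=1), as documented above
    gamma=0.3,            # ratio of expected generator/discriminator losses
    lambda_k=0.001,       # learning rate of the control-theory update
    regularizer=False,    # no extra regularization term in the GAROM loss
)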
@@ -112,19 +100,15 @@ class GAROM(MultiSolverInterface):
def forward(self, x, mc_steps=20, variance=False):
"""
Forward step for GAROM solver
Forward pass implementation.
:param x: The input tensor.
:type x: torch.Tensor
:param mc_steps: Number of montecarlo samples to approximate the
expected value, defaults to 20.
:type mc_steps: int
:param variance: Returining also the sample variance of the solution,
defaults to False.
:type variance: bool
:param torch.Tensor x: The input tensor.
:param int mc_steps: Number of Monte Carlo samples to approximate the
expected value. Default is ``20``.
:param bool variance: If ``True``, the method also returns the sample
variance of the solution. Default is ``False``.
:return: The expected value of the generator distribution. If
``variance=True`` also the
sample variance is returned.
``variance=True``, the method also returns the sample variance.
:rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
"""
@@ -142,13 +126,24 @@ class GAROM(MultiSolverInterface):
return mean
def sample(self, x):
"""TODO"""
"""
Sample from the generator distribution.
:param torch.Tensor x: The input tensor.
:return: The generated sample.
:rtype: torch.Tensor
"""
# sampling
return self.generator(x)
def _train_generator(self, parameters, snapshots):
"""
Private method to train the generator network.
Train the generator model.
:param torch.Tensor parameters: The input tensor.
:param torch.Tensor snapshots: The target tensor.
:return: The residual loss and the generator loss.
:rtype: tuple(torch.Tensor, torch.Tensor)
"""
optimizer = self.optimizer_generator
optimizer.zero_grad()
@@ -170,16 +165,13 @@ class GAROM(MultiSolverInterface):
def on_train_batch_end(self, outputs, batch, batch_idx):
"""
This method is called at the end of each training batch, and ovverides
the PytorchLightining implementation for logging the checkpoints.
This method is called at the end of each training batch and overrides
the PyTorch Lightning implementation to log checkpoints.
:param torch.Tensor outputs: The output from the model for the
current batch.
:param tuple batch: The current batch of data.
:param torch.Tensor outputs: The ``model``'s output for the current
batch.
:param dict batch: The current batch of data.
:param int batch_idx: The index of the current batch.
:return: Whatever is returned by the parent
method ``on_train_batch_end``.
:rtype: Any
"""
# increase by one the counter of optimization to save loggers
(
@@ -190,7 +182,12 @@ class GAROM(MultiSolverInterface):
def _train_discriminator(self, parameters, snapshots):
"""
Private method to train the discriminator network.
Train the discriminator model.
:param torch.Tensor parameters: The input tensor.
:param torch.Tensor snapshots: The target tensor.
:return: The discriminator losses on dataset samples and on generated
samples.
:rtype: tuple(torch.Tensor, torch.Tensor)
"""
optimizer = self.optimizer_discriminator
optimizer.zero_grad()
@@ -215,8 +212,15 @@ class GAROM(MultiSolverInterface):
def _update_weights(self, d_loss_real, d_loss_fake):
"""
Private method to Update the weights of the generator and discriminator
networks.
Update the weights of the generator and discriminator models.
:param torch.Tensor d_loss_real: The discriminator loss computed on
dataset samples.
:param torch.Tensor d_loss_fake: The discriminator loss computed on
generated samples.
:return: The difference between the loss computed on the dataset samples
and the loss computed on the generated samples.
:rtype: torch.Tensor
"""
diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
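The ``gamma``/``lambda_k`` description and the ``diff`` computed above
are consistent with a BEGAN-style equilibrium update, in which a control
variable :math:`k` is nudged at every optimization step. A plausible
reading of the update, which is not shown in this hunk:

.. math::
    k_{t+1} = k_t + \lambda_k \left(\gamma \mathcal{L}_{\rm{real}} -
    \mathcal{L}_{\rm{fake}}\right),

with :math:`k` kept in :math:`[0, 1]` and ``diff`` returned for logging.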
@@ -227,11 +231,11 @@ class GAROM(MultiSolverInterface):
return diff
def optimization_cycle(self, batch):
"""GAROM solver training step.
"""
The optimization cycle for the GAROM solver.
:param batch: The batch element in the dataloader.
:type batch: tuple
:return: The sum of the loss functions.
:param tuple batch: The batch element in the dataloader.
:return: The loss of the optimization cycle.
:rtype: LabelTensor
"""
condition_loss = {}
@@ -258,6 +262,13 @@ class GAROM(MultiSolverInterface):
return condition_loss
def validation_step(self, batch):
"""
The validation step for the GAROM solver.
:param dict batch: The batch of data to use in the validation step.
:return: The loss of the validation step.
:rtype: torch.Tensor
"""
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = (
@@ -273,6 +284,13 @@ class GAROM(MultiSolverInterface):
return loss
def test_step(self, batch):
"""
The test step for the GAROM solver.
:param dict batch: The batch of data to use in the test step.
:return: The loss of the test step.
:rtype: torch.Tensor
"""
condition_loss = {}
for condition_name, points in batch:
parameters, snapshots = (
@@ -289,30 +307,60 @@ class GAROM(MultiSolverInterface):
@property
def generator(self):
"""TODO"""
"""
The generator model.
:return: The generator model.
:rtype: torch.nn.Module
"""
return self.models[0]
@property
def discriminator(self):
"""TODO"""
"""
The discriminator model.
:return: The discriminator model.
:rtype: torch.nn.Module
"""
return self.models[1]
@property
def optimizer_generator(self):
"""TODO"""
"""
The optimizer for the generator.
:return: The optimizer for the generator.
:rtype: Optimizer
"""
return self.optimizers[0].instance
@property
def optimizer_discriminator(self):
"""TODO"""
"""
The optimizer for the discriminator.
:return: The optimizer for the discriminator.
:rtype: Optimizer
"""
return self.optimizers[1].instance
@property
def scheduler_generator(self):
"""TODO"""
"""
The scheduler for the generator.
:return: The scheduler for the generator.
:rtype: Scheduler
"""
return self.schedulers[0].instance
@property
def scheduler_discriminator(self):
"""TODO"""
"""
The scheduler for the discriminator.
:return: The scheduler for the discriminator.
:rtype: Scheduler
"""
return self.schedulers[1].instance
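The properties above expose the paired components by position, with the
generator first and the discriminator second; a short usage sketch, with
``solver`` and ``x`` as in the earlier sketches:

samples = solver.sample(x)               # one draw from the generator
g_net = solver.generator                 # models[0]
d_opt = solver.optimizer_discriminator   # optimizers[1].instance
g_sched = solver.scheduler_generator     # schedulers[0].instance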

View File

@@ -110,8 +110,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
def loss_data(self, input_pts, output_pts):
"""
Compute the data loss for the PINN solver by evaluating the loss
between the network's output and the true solution. This method
should only be overridden intentionally.
between the network's output and the true solution. This method should
not be overridden unless intentionally.
:param LabelTensor input_pts: The input points to the neural network.
:param LabelTensor output_pts: The true solution to compare with the

View File

@@ -1,19 +1,17 @@
"""Module for ReducedOrderModelSolver"""
import torch
from .supervised import SupervisedSolver
class ReducedOrderModelSolver(SupervisedSolver):
r"""
ReducedOrderModelSolver solver class. This class implements a
Reduced Order Model solver, using user specified ``reduction_network`` and
Reduced Order Model solver class. This class implements the Reduced Order
Model solver, using user-specified ``reduction_network`` and
``interpolation_network`` to solve a specific ``problem``.
The Reduced Order Model approach aims to find
the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
of the differential problem:
The Reduced Order Model solver aims to find the solution
:math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential problem:
.. math::
@@ -23,13 +21,13 @@ class ReducedOrderModelSolver(SupervisedSolver):
\mathbf{x}\in\partial\Omega
\end{cases}
This is done by using two neural networks. The ``reduction_network``, which
contains an encoder :math:`\mathcal{E}_{\rm{net}}`, a decoder
:math:`\mathcal{D}_{\rm{net}}`; and an ``interpolation_network``
This is done by means of two neural networks: the ``reduction_network``,
which defines an encoder :math:`\mathcal{E}_{\rm{net}}`, and a decoder
:math:`\mathcal{D}_{\rm{net}}`; and the ``interpolation_network``
:math:`\mathcal{I}_{\rm{net}}`. The input is assumed to be discretised in
the spatial dimensions.
The following loss function is minimized during training
The following loss function is minimized during training:
.. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
@@ -39,49 +37,46 @@ class ReducedOrderModelSolver(SupervisedSolver):
\mathcal{D}_{\rm{net}}[\mathcal{E}_{\rm{net}}[\mathbf{u}(\mu_i)]] -
\mathbf{u}(\mu_i))
where :math:`\mathcal{L}` is a specific loss function, default
Mean Square Error:
where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
.. seealso::
**Original reference**: Hesthaven, Jan S., and Stefano Ubbiali.
"Non-intrusive reduced order modeling of nonlinear problems
using neural networks." Journal of Computational
Physics 363 (2018): 55-78.
"Non-intrusive reduced order modeling of nonlinear problems using
neural networks."
Journal of Computational Physics 363 (2018): 55-78.
DOI `10.1016/j.jcp.2018.02.037
<https://doi.org/10.1016/j.jcp.2018.02.037>`_.
.. note::
The specified ``reduction_network`` must contain two methods,
namely ``encode`` for input encoding and ``decode`` for decoding the
former result. The ``interpolation_network`` network ``forward`` output
represents the interpolation of the latent space obtain with
The specified ``reduction_network`` must contain two methods, namely
``encode`` for input encoding, and ``decode`` for decoding the former
result. The ``forward`` output of the ``interpolation_network``
represents the interpolation of the latent space obtained with
``reduction_network.encode``.
.. note::
This solver uses the end-to-end training strategy, i.e. the
``reduction_network`` and ``interpolation_network`` are trained
simultaneously. For reference on this trainig strategy look at:
Pichi, Federico, Beatriz Moya, and Jan S. Hesthaven.
simultaneously. For reference on this training strategy, see the
following:
.. seealso::
**Original reference**: Pichi, Federico, Beatriz Moya, and Jan S.
Hesthaven.
"A graph convolutional autoencoder approach to model order reduction
for parametrized PDEs." Journal of
Computational Physics 501 (2024): 112762.
DOI
`10.1016/j.jcp.2024.112762 <https://doi.org/10.1016/
j.jcp.2024.112762>`_.
for parametrized PDEs."
Journal of Computational Physics 501 (2024): 112762.
DOI `10.1016/j.jcp.2024.112762
<https://doi.org/10.1016/j.jcp.2024.112762>`_.
.. warning::
This solver works only for data-driven models. Hence, in the ``problem``
definition, the conditions must only contain ``input``
(e.g. coefficient parameters, time parameters) and ``target``.
.. warning::
This solver does not currently support the possibility to pass
``extra_feature``.
"""
def __init__(
@@ -96,22 +91,28 @@ class ReducedOrderModelSolver(SupervisedSolver):
use_lt=True,
):
"""
Initialization of the :class:`ReducedOrderModelSolver` class.
:param AbstractProblem problem: The formulation of the problem.
:param torch.nn.Module reduction_network: The reduction network used
for reducing the input space. It must contain two methods,
namely ``encode`` for input encoding and ``decode`` for decoding the
for reducing the input space. It must contain two methods, namely
``encode`` for input encoding, and ``decode`` for decoding the
former result.
:param torch.nn.Module interpolation_network: The interpolation network
for interpolating the control parameters to latent space obtain by
for interpolating the control parameters to latent space obtained by
the ``reduction_network`` encoding.
:param torch.nn.Module loss: The loss function used as minimizer,
default :class:`torch.nn.MSELoss`.
:param torch.optim.Optimizer optimizer: The neural network optimizer to
use; default is :class:`torch.optim.Adam`.
:param torch.optim.LRScheduler scheduler: Learning
rate scheduler.
:param WeightingInterface weighting: The loss weighting to use.
:param bool use_lt: Using LabelTensors as input during training.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the :class:`torch.nn.MSELoss` loss is used.
Default is ``None``.
:param Optimizer optimizer: The optimizer to be used.
If ``None``, the :class:`torch.optim.Adam` optimizer is used.
Default is ``None``.
:param Scheduler scheduler: Learning rate scheduler. If ``None``,
the constant learning rate scheduler is used. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
Default is ``True``.
"""
model = torch.nn.ModuleDict(
{
@@ -146,10 +147,10 @@ class ReducedOrderModelSolver(SupervisedSolver):
def forward(self, x):
"""
Forward pass implementation for the solver. It finds the encoder
representation by calling ``interpolation_network.forward`` on the
input, and maps this representation to output space by calling
``reduction_network.decode``.
Forward pass implementation.
It computes the encoder representation by calling the forward method
of the ``interpolation_network`` on the input, and maps it to output
space by calling the ``decode`` method of the ``reduction_network``.
:param torch.Tensor x: Input tensor.
:return: Solver solution.
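In code, that composition presumably reduces to two calls; a sketch
only, since the keys of the ``torch.nn.ModuleDict`` built in
``__init__`` are not shown in this hunk and are assumed here:

latent = self.model["interpolation_network"](x)        # I_net: mu -> latent
return self.model["reduction_network"].decode(latent)  # D_net: latent -> u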
@@ -161,15 +162,14 @@ class ReducedOrderModelSolver(SupervisedSolver):
def loss_data(self, input_pts, output_pts):
"""
The data loss for the ReducedOrderModelSolver solver.
It computes the loss between
the network output against the true solution. This function
should not be override if not intentionally.
Compute the data loss by evaluating the loss between the network's
output and the true solution. This method should not be overridden
unless intentionally.
:param LabelTensor input_tensor: The input to the neural networks.
:param LabelTensor output_tensor: The true solution to compare the
network solution.
:return: The residual loss averaged on the input coordinates
:param LabelTensor input_pts: The input points to the neural network.
:param LabelTensor output_pts: The true solution to compare with the
network's output.
:return: The supervised loss, averaged over the number of observations.
:rtype: torch.Tensor
"""
# extract networks
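Combining the two notes above: the ``reduction_network`` only has to
expose an ``encode``/``decode`` pair, and both networks are trained end
to end. A minimal sketch, with an assumed import path and illustrative
layer sizes:

import torch
from pina.solver import ReducedOrderModelSolver  # assumed import path

class Autoencoder(torch.nn.Module):
    # Toy reduction network exposing the required encode/decode methods.
    def __init__(self, full_dim=100, latent_dim=10):
        super().__init__()
        self.encoder = torch.nn.Linear(full_dim, latent_dim)
        self.decoder = torch.nn.Linear(latent_dim, full_dim)

    def encode(self, u):
        return self.encoder(u)

    def decode(self, z):
        return self.decoder(z)

interpolation = torch.nn.Sequential(   # I_net: parameters -> latent space
    torch.nn.Linear(2, 32), torch.nn.Tanh(), torch.nn.Linear(32, 10)
)

solver = ReducedOrderModelSolver(
    problem=problem,                   # data-driven problem, defined elsewhere
    reduction_network=Autoencoder(),
    interpolation_network=interpolation,
)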

View File

@@ -14,17 +14,19 @@ from ..utils import check_consistency, labelize_forward
class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
SolverInterface base class. This class is a wrapper of LightningModule.
Abstract base class for PINA solvers. All specific solvers should inherit
from this interface. This class is a wrapper of
:class:`~lightning.pytorch.LightningModule`.
"""
def __init__(self, problem, weighting, use_lt):
"""
:param problem: A problem definition instance.
:type problem: AbstractProblem
:param weighting: The loss weighting to use.
:type weighting: WeightingInterface
:param use_lt: Using LabelTensors as input during training.
:type use_lt: bool
Initialization of the :class:`SolverInterface` class.
:param AbstractProblem problem: The problem to be solved.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
"""
super().__init__()
@@ -59,22 +61,24 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
self._pina_schedulers = None
def _check_solver_consistency(self, problem):
"""
Check the consistency of the solver with the problem formulation.
:param AbstractProblem problem: The problem to be solved.
"""
for condition in problem.conditions.values():
check_consistency(condition, self.accepted_conditions_types)
def _optimization_cycle(self, batch):
"""
Perform a private optimization cycle by computing the loss for each
condition in the given batch. The loss are later aggregated using the
specific weighting schema.
Aggregate the loss for each condition in the batch.
:param batch: A batch of data, where each element is a tuple containing
a condition name and a dictionary of points.
:type batch: list of tuples (str, dict)
:return: The computed loss for the all conditions in the batch,
cast to a subclass of `torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict(torch.Tensor)
:param list[tuple[str, dict]] batch: A batch of data. Each element is a
tuple containing a condition name and a dictionary of points.
:return: The computed loss for all the conditions in the batch, cast
to a subclass of ``torch.Tensor``. It should return a dict containing
the condition name and the associated scalar loss.
:rtype: dict
"""
losses = self.optimization_cycle(batch)
for name, value in losses.items():
@@ -88,9 +92,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Solver training step.
:param batch: The batch element in the dataloader.
:type batch: tuple
:return: The sum of the loss functions.
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
:return: The loss of the training step.
:rtype: LabelTensor
"""
loss = self._optimization_cycle(batch=batch)
@@ -101,8 +104,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Solver validation step.
:param batch: The batch element in the dataloader.
:type batch: tuple
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
"""
loss = self._optimization_cycle(batch=batch)
self.store_log("val_loss", loss, self.get_batch_size(batch))
@@ -111,15 +113,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
"""
Solver test step.
:param batch: The batch element in the dataloader.
:type batch: tuple
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
"""
loss = self._optimization_cycle(batch=batch)
self.store_log("test_loss", loss, self.get_batch_size(batch))
def store_log(self, name, value, batch_size):
"""
TODO
Store the log of the solver.
:param str name: The name of the log.
:param torch.Tensor value: The value of the log.
:param int batch_size: The size of the batch.
"""
self.log(
@@ -132,49 +137,59 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
@abstractmethod
def forward(self, *args, **kwargs):
"""
TODO
Abstract method for the forward pass implementation.
"""
@abstractmethod
def optimization_cycle(self, batch):
"""
Perform an optimization cycle by computing the loss for each condition
in the given batch.
The optimization cycle for the solver.
:param batch: A batch of data, where each element is a tuple containing
a condition name and a dictionary of points.
:type batch: list of tuples (str, dict)
:return: The computed loss for the all conditions in the batch,
cast to a subclass of `torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict(torch.Tensor)
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
:return: The computed loss for all the conditions in the batch, cast
to a subclass of ``torch.Tensor``. It should return a dict containing
the condition name and the associated scalar loss.
:rtype: dict
"""
@property
def problem(self):
"""
The problem formulation.
The problem instance.
:return: The problem instance.
:rtype: :class:`~pina.problem.abstract_problem.AbstractProblem`
"""
return self._pina_problem
@property
def use_lt(self):
"""
Using LabelTensor in training.
Whether LabelTensors are used as input during training.
:return: The use_lt attribute.
:rtype: bool
"""
return self._use_lt
@property
def weighting(self):
"""
The weighting mechanism.
The weighting schema.
:return: The weighting schema.
:rtype: :class:`~pina.loss.weighting_interface.WeightingInterface`
"""
return self._pina_weighting
@staticmethod
def get_batch_size(batch):
"""
TODO
Get the batch size.
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
:return: The size of the batch.
:rtype: int
"""
batch_size = 0
@@ -185,23 +200,29 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
@staticmethod
def default_torch_optimizer():
"""
TODO
"""
Return the default optimizer: :class:`torch.optim.Adam` with learning
rate ``0.001``.
:return: The default optimizer.
:rtype: Optimizer
"""
return TorchOptimizer(torch.optim.Adam, lr=0.001)
@staticmethod
def default_torch_scheduler():
"""
TODO
Return the default scheduler:
:class:`torch.optim.lr_scheduler.ConstantLR`.
:return: The default scheduler.
:rtype: Scheduler
"""
return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
def on_train_start(self):
"""
Hook that is called before training begins.
Used to compile the model if the trainer is set to compile.
This method is called at the start of the training process to compile
the model, if the ``compile`` flag of the :class:`~pina.trainer.Trainer`
is ``True``.
"""
super().on_train_start()
if self.trainer.compile:
@@ -209,8 +230,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
def on_test_start(self):
"""
Hook that is called before training begins.
Used to compile the model if the trainer is set to compile.
This method is called at the start of the test process to compile
the model, if the ``compile`` flag of the :class:`~pina.trainer.Trainer`
is ``True``.
"""
super().on_test_start()
if self.trainer.compile and not self._check_already_compiled():
@@ -218,7 +239,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
def _check_already_compiled(self):
"""
TODO
Check if the model is already compiled.
:return: ``True`` if the model is already compiled, ``False`` otherwise.
:rtype: bool
"""
models = self._pina_models
@@ -234,7 +258,12 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
@staticmethod
def _perform_compilation(model):
"""
TODO
Perform the compilation of the model.
:param torch.nn.Module model: The model to compile.
:raises Exception: If the compilation fails.
:return: The compiled model.
:rtype: torch.nn.Module
"""
model_device = next(model.parameters()).device
@@ -249,8 +278,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
"""TODO"""
"""
Base class for PINA solvers using a single :class:`torch.nn.Module`.
"""
def __init__(
self,
problem,
@@ -261,14 +291,18 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
use_lt=True,
):
"""
:param problem: A problem definition instance.
:type problem: AbstractProblem
:param model: A torch nn.Module instances.
:type model: torch.nn.Module
:param Optimizer optimizers: A neural network optimizers to use.
:param Scheduler optimizers: A neural network scheduler to use.
:param WeightingInterface weighting: The loss weighting to use.
:param bool use_lt: Using LabelTensors as input during training.
Initialization of the :class:`SingleSolverInterface` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param Optimizer optimizer: The optimizer to be used.
If ``None``, the Adam optimizer is used. Default is ``None``.
:param Scheduler scheduler: The scheduler to be used.
If ``None``, the constant learning rate scheduler is used.
Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
"""
if optimizer is None:
optimizer = self.default_torch_optimizer()
@@ -292,11 +326,12 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
def forward(self, x):
"""
Forward pass implementation for the solver.
Forward pass implementation.
:param torch.Tensor x: Input tensor.
:param x: Input tensor.
:type x: torch.Tensor | LabelTensor
:return: Solver solution.
:rtype: torch.Tensor
:rtype: torch.Tensor | LabelTensor
"""
x = self.model(x)
return x
@@ -305,7 +340,7 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
"""
Optimizer configuration for the solver.
:return: The optimizers and the schedulers
:return: The optimizer and the scheduler
:rtype: tuple(list, list)
"""
self.optimizer.hook(self.model.parameters())
@@ -313,44 +348,61 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
return ([self.optimizer.instance], [self.scheduler.instance])
def _compile_model(self):
"""
Compile the model.
"""
if isinstance(self._pina_models[0], torch.nn.ModuleDict):
self._compile_module_dict()
else:
self._compile_single_model()
def _compile_module_dict(self):
"""
Compile the model if it is a :class:`torch.nn.ModuleDict`.
"""
for name, model in self._pina_models[0].items():
self._pina_models[0][name] = self._perform_compilation(model)
def _compile_single_model(self):
"""
Compile the model if it is a single :class:`torch.nn.Module`.
"""
self._pina_models[0] = self._perform_compilation(self._pina_models[0])
@property
def model(self):
"""
Model for training.
The model used for training.
:return: The model used for training.
:rtype: torch.nn.Module
"""
return self._pina_models[0]
@property
def scheduler(self):
"""
Scheduler for training.
The scheduler used for training.
:return: The scheduler used for training.
:rtype: Scheduler
"""
return self._pina_schedulers[0]
@property
def optimizer(self):
"""
Optimizer for training.
The optimizer used for training.
:return: The optimizer used for training.
:rtype: Optimizer
"""
return self._pina_optimizers[0]
class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
"""
Multiple Solver base class. This class inherits is a wrapper of
SolverInterface class
Base class for PINA solvers using multiple :class:`torch.nn.Module`.
"""
def __init__(
@@ -363,16 +415,22 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
use_lt=True,
):
"""
:param problem: A problem definition instance.
:type problem: AbstractProblem
:param models: Multiple torch nn.Module instances.
Initialization of the :class:`MultiSolverInterface` class.
:param AbstractProblem problem: The problem to be solved.
:param models: The neural network models to be used.
:type models: list[torch.nn.Module] | tuple[torch.nn.Module]
:param list(Optimizer) optimizers: A list of neural network
optimizers to use.
:param list(Scheduler) optimizers: A list of neural network
schedulers to use.
:param WeightingInterface weighting: The loss weighting to use.
:param bool use_lt: Using LabelTensors as input during training.
:param list[Optimizer] optimizers: The optimizers to be used.
If ``None``, the Adam optimizer is used for all models.
Default is ``None``.
:param list[Scheduler] schedulers: The schedulers to be used.
If ``None``, the constant learning rate scheduler is used for all the
models. Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
:raises ValueError: If the models are not a list or tuple with length
greater than one.
"""
if not isinstance(models, (list, tuple)) or len(models) < 2:
raise ValueError(
@@ -418,9 +476,10 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
self._pina_schedulers = schedulers
def configure_optimizers(self):
"""Optimizer configuration for the solver.
"""
Optimizer configuration for the solver.
:return: The optimizers and the schedulers
:return: The optimizers and the schedulers for all the models
:rtype: tuple(list, list)
"""
for optimizer, scheduler, model in zip(
@@ -435,6 +494,9 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
)
def _compile_model(self):
"""
Compile the model.
"""
for i, model in enumerate(self._pina_models):
if not isinstance(model, torch.nn.ModuleDict):
self._pina_models[i] = self._perform_compilation(model)
@@ -442,17 +504,29 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
@property
def models(self):
"""
The torch model."""
The models used for training.
:return: The models used for training.
:rtype: torch.nn.ModuleList
"""
return self._pina_models
@property
def optimizers(self):
"""
The torch model."""
The optimizers used for training.
:return: The optimizers used for training.
:rtype: list[Optimizer]
"""
return self._pina_optimizers
@property
def schedulers(self):
"""
The torch model."""
The schedulers used for training.
:return: The schedulers used for training.
:rtype: list[Scheduler]
"""
return self._pina_schedulers
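To make the contract concrete: a concrete solver supplies
``optimization_cycle`` returning a dict of per-condition scalar losses,
and inherits the rest of the machinery. A minimal sketch against
:class:`SingleSolverInterface`, assuming the ``input``/``target`` batch
keys that the data-driven solvers in this commit document:

import torch
from pina.solver import SingleSolverInterface  # assumed import path
from pina.condition import InputTargetCondition

class L1Solver(SingleSolverInterface):
    # Toy solver minimizing an L1 misfit for every condition.
    accepted_conditions_types = InputTargetCondition

    def optimization_cycle(self, batch):
        losses = {}
        for condition_name, points in batch:
            # "input"/"target" keys assumed from the warnings above
            prediction = self.forward(points["input"])
            losses[condition_name] = torch.nn.functional.l1_loss(
                prediction, points["target"]
            )
        return losses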

View File

@@ -1,4 +1,4 @@
"""Module for SupervisedSolver"""
"""Module for the Supervised Solver."""
import torch
from torch.nn.modules.loss import _Loss
@@ -10,31 +10,28 @@ from ..condition import InputTargetCondition
class SupervisedSolver(SingleSolverInterface):
r"""
SupervisedSolver solver class. This class implements a SupervisedSolver,
Supervised solver class. This class implements a supervised solver,
using a user-specified ``model`` to solve a specific ``problem``.
The Supervised Solver class aims to find
a map between the input :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m`
and the output :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. The input
can be discretised in space (as in :obj:`~pina.solver.rom.ROMe2eSolver`),
or not (e.g. when training Neural Operators).
The Supervised Solver class aims to find a map between the input
:math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m` and the output
:math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`.
Given a model :math:`\mathcal{M}`, the following loss function is
minimized during training:
.. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i))
\mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i)),
where :math:`\mathcal{L}` is a specific loss function,
default Mean Square Error:
where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
In this context :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` means that
we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions.
In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` indicate
that the goal is to approximate multiple (discretised) functions given
multiple (discretised) input functions.
"""
accepted_conditions_types = InputTargetCondition
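With the default MSE, the loss above is nothing more than the mean
squared misfit between targets and model outputs. Spelled out as a
sketch, with ``u`` the stacked targets :math:`\mathbf{u}_i` and ``v``
the stacked inputs :math:`\mathbf{v}_i`:

# L_problem = (1/N) * sum_i || u_i - M(v_i) ||_2^2, the formula above with MSE
loss = torch.nn.MSELoss()(model(v), u)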
@@ -50,16 +47,22 @@ class SupervisedSolver(SingleSolverInterface):
use_lt=True,
):
"""
:param AbstractProblem problem: The formualation of the problem.
:param torch.nn.Module model: The neural network model to use.
:param torch.nn.Module loss: The loss function used as minimizer,
default :class:`torch.nn.MSELoss`.
:param torch.optim.Optimizer optimizer: The neural network optimizer to
use; default is :class:`torch.optim.Adam`.
:param torch.optim.LRScheduler scheduler: Learning
rate scheduler.
:param WeightingInterface weighting: The loss weighting to use.
:param bool use_lt: Using LabelTensors as input during training.
Initialization of the :class:`SupervisedSolver` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param torch.nn.Module loss: The loss function to be minimized.
If ``None``, the Mean Squared Error (MSE) loss is used.
Default is ``None``.
:param torch.optim.Optimizer optimizer: The optimizer to be used.
If ``None``, the Adam optimizer is used. Default is ``None``.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler.
If ``None``, the constant learning rate scheduler is used.
Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If ``None``, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
Default is ``True``.
"""
if loss is None:
loss = torch.nn.MSELoss()
@@ -81,16 +84,13 @@ class SupervisedSolver(SingleSolverInterface):
def optimization_cycle(self, batch):
"""
Perform an optimization cycle by computing the loss for each condition
in the given batch.
The optimization cycle for the supervised solver.
:param batch: A batch of data, where each element is a tuple containing
a condition name and a dictionary of points.
:type batch: list of tuples (str, dict)
:return: The computed loss for the all conditions in the batch,
cast to a subclass of `torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict(torch.Tensor)
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
:return: The computed loss for all the conditions in the batch, cast
to a subclass of ``torch.Tensor``. It should return a dict containing
the condition name and the associated scalar loss.
:rtype: dict
"""
condition_loss = {}
for condition_name, points in batch:
@@ -105,16 +105,16 @@ class SupervisedSolver(SingleSolverInterface):
def loss_data(self, input_pts, output_pts):
"""
The data loss for the Supervised solver. It computes the loss between
the network output against the true solution. This function
should not be override if not intentionally.
Compute the data loss for the Supervised solver by evaluating the loss
between the network's output and the true solution. This method should
not be overridden unless intentionally.
:param input_pts: The input to the neural networks.
:param input_pts: The input points to the neural network.
:type input_pts: LabelTensor | torch.Tensor
:param output_pts: The true solution to compare the
network solution.
:param output_pts: The true solution to compare with the network's
output.
:type output_pts: LabelTensor | torch.Tensor
:return: The residual loss.
:return: The supervised loss, averaged over the number of observations.
:rtype: torch.Tensor
"""
return self._loss(self.forward(input_pts), output_pts)
@@ -122,6 +122,9 @@ class SupervisedSolver(SingleSolverInterface):
@property
def loss(self):
"""
Loss for training.
The loss function to be minimized.
:return: The loss function to be minimized.
:rtype: torch.nn.Module
"""
return self._loss
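End to end, each of these solvers is driven by the trainer referenced in
``on_train_start`` above; a closing sketch, with assumed import paths
and ``problem`` defined elsewhere:

import torch
from pina import Trainer                    # assumed import path
from pina.solver import SupervisedSolver    # assumed import path

model = torch.nn.Sequential(
    torch.nn.Linear(2, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
)
solver = SupervisedSolver(problem=problem, model=model)  # MSE, Adam, constant LR
trainer = Trainer(solver=solver, max_epochs=100)
trainer.train()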