fix doc solver

giovanni
2025-03-13 19:05:15 +01:00
committed by FilippoOlivo
parent fd2b50fc06
commit 5e6aa61592
6 changed files with 374 additions and 251 deletions

View File

@@ -1,6 +1,4 @@
-"""
-TODO
-"""
+"""Module for the solver classes."""
 
 __all__ = [
     "SolverInterface",

View File

@@ -1,4 +1,4 @@
-"""Module for GAROM"""
+"""Module for the GAROM solver."""
 
 import torch
 from torch.nn.modules.loss import _Loss
@@ -10,9 +10,9 @@ from ..loss import LossInterface, PowerLoss
 class GAROM(MultiSolverInterface):
     """
-    GAROM solver class. This class implements Generative Adversarial
-    Reduced Order Model solver, using user specified ``models`` to solve
-    a specific order reduction``problem``.
+    GAROM solver class. This class implements Generative Adversarial Reduced
+    Order Model solver, using user specified ``models`` to solve a specific
+    order reduction ``problem``.
 
     .. seealso::
@@ -39,40 +39,28 @@ class GAROM(MultiSolverInterface):
         regularizer=False,
     ):
         """
-        :param AbstractProblem problem: The formualation of the problem.
-        :param torch.nn.Module generator: The neural network model to use
-            for the generator.
-        :param torch.nn.Module discriminator: The neural network model to use
-            for the discriminator.
-        :param torch.nn.Module loss: The loss function used as minimizer,
-            default ``None``. If ``loss`` is ``None`` the defualt
-            ``PowerLoss(p=1)`` is used, as in the original paper.
-        :param Optimizer optimizer_generator: The neural
-            network optimizer to use for the generator network
-            , default is `torch.optim.Adam`.
-        :param Optimizer optimizer_discriminator: The neural
-            network optimizer to use for the discriminator network
-            , default is `torch.optim.Adam`.
-        :param Scheduler scheduler_generator: Learning
-            rate scheduler for the generator.
-        :param Scheduler scheduler_discriminator: Learning
-            rate scheduler for the discriminator.
-        :param dict scheduler_discriminator_kwargs: LR scheduler constructor
-            keyword args.
-        :param gamma: Ratio of expected loss for generator and discriminator,
-            defaults to 0.3.
-        :type gamma: float
-        :param lambda_k: Learning rate for control theory optimization,
-            defaults to 0.001.
-        :type lambda_k: float
-        :param regularizer: Regularization term in the GAROM loss,
-            defaults to False.
-        :type regularizer: bool
-
-        .. warning::
-            The algorithm works only for data-driven model. Hence in the
-            ``problem`` definition the codition must only contain ``input``
-            (e.g. coefficient parameters, time parameters), and ``target``.
+        Initialization of the :class:`GAROM` class.
+
+        :param AbstractProblem problem: The formulation of the problem.
+        :param torch.nn.Module generator: The generator model.
+        :param torch.nn.Module discriminator: The discriminator model.
+        :param torch.nn.Module loss: The loss function to be minimized.
+            If ``None``, ``PowerLoss(p=1)`` is used. Default is ``None``.
+        :param Optimizer optimizer_generator: The optimizer for the generator.
+            If ``None``, the Adam optimizer is used. Default is ``None``.
+        :param Optimizer optimizer_discriminator: The optimizer for the
+            discriminator. If ``None``, the Adam optimizer is used.
+            Default is ``None``.
+        :param Scheduler scheduler_generator: The learning rate scheduler for
+            the generator.
+        :param Scheduler scheduler_discriminator: The learning rate scheduler
+            for the discriminator.
+        :param float gamma: Ratio of expected loss for generator and
+            discriminator. Default is ``0.3``.
+        :param float lambda_k: Learning rate for control theory optimization.
+            Default is ``0.001``.
+        :param bool regularizer: If ``True``, uses a regularization term in the
+            GAROM loss. Default is ``False``.
         """
         # set loss
@@ -112,19 +100,15 @@ class GAROM(MultiSolverInterface):
     def forward(self, x, mc_steps=20, variance=False):
         """
-        Forward step for GAROM solver
+        Forward pass implementation.
 
-        :param x: The input tensor.
-        :type x: torch.Tensor
-        :param mc_steps: Number of montecarlo samples to approximate the
-            expected value, defaults to 20.
-        :type mc_steps: int
-        :param variance: Returining also the sample variance of the solution,
-            defaults to False.
-        :type variance: bool
-        :return: The expected value of the generator distribution. If
-            ``variance=True`` also the
-            sample variance is returned.
+        :param torch.Tensor x: The input tensor.
+        :param int mc_steps: Number of Monte Carlo samples to approximate the
+            expected value. Default is ``20``.
+        :param bool variance: If ``True``, the method returns also the variance
+            of the solution. Default is ``False``.
+        :return: The expected value of the generator distribution. If
+            ``variance=True``, the method returns also the variance.
         :rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor)
         """
@@ -142,13 +126,24 @@ class GAROM(MultiSolverInterface):
         return mean
 
     def sample(self, x):
-        """TODO"""
+        """
+        Sample from the generator distribution.
+
+        :param torch.Tensor x: The input tensor.
+        :return: The generated sample.
+        :rtype: torch.Tensor
+        """
         # sampling
         return self.generator(x)
 
     def _train_generator(self, parameters, snapshots):
         """
-        Private method to train the generator network.
+        Train the generator model.
+
+        :param torch.Tensor parameters: The input tensor.
+        :param torch.Tensor snapshots: The target tensor.
+        :return: The residual loss and the generator loss.
+        :rtype: tuple(torch.Tensor, torch.Tensor)
         """
         optimizer = self.optimizer_generator
         optimizer.zero_grad()
@@ -170,16 +165,13 @@ class GAROM(MultiSolverInterface):
     def on_train_batch_end(self, outputs, batch, batch_idx):
         """
-        This method is called at the end of each training batch, and ovverides
-        the PytorchLightining implementation for logging the checkpoints.
+        This method is called at the end of each training batch and overrides
+        the PyTorch Lightning implementation to log checkpoints.
 
-        :param torch.Tensor outputs: The output from the model for the
-            current batch.
-        :param tuple batch: The current batch of data.
+        :param torch.Tensor outputs: The ``model``'s output for the current
+            batch.
+        :param dict batch: The current batch of data.
         :param int batch_idx: The index of the current batch.
-        :return: Whatever is returned by the parent
-            method ``on_train_batch_end``.
-        :rtype: Any
         """
         # increase by one the counter of optimization to save loggers
         (
@@ -190,7 +182,12 @@ class GAROM(MultiSolverInterface):
     def _train_discriminator(self, parameters, snapshots):
         """
-        Private method to train the discriminator network.
+        Train the discriminator model.
+
+        :param torch.Tensor parameters: The input tensor.
+        :param torch.Tensor snapshots: The target tensor.
+        :return: The discriminator loss computed on dataset samples and on
+            generated samples.
+        :rtype: tuple(torch.Tensor, torch.Tensor)
         """
         optimizer = self.optimizer_discriminator
         optimizer.zero_grad()
@@ -215,8 +212,15 @@ class GAROM(MultiSolverInterface):
     def _update_weights(self, d_loss_real, d_loss_fake):
         """
-        Private method to Update the weights of the generator and discriminator
-        networks.
+        Update the weights of the generator and discriminator models.
+
+        :param torch.Tensor d_loss_real: The discriminator loss computed on
+            dataset samples.
+        :param torch.Tensor d_loss_fake: The discriminator loss computed on
+            generated samples.
+        :return: The difference between the loss computed on the dataset
+            samples and the loss computed on the generated samples.
+        :rtype: torch.Tensor
         """
         diff = torch.mean(self.gamma * d_loss_real - d_loss_fake)
@@ -227,11 +231,11 @@ class GAROM(MultiSolverInterface):
         return diff
 
     def optimization_cycle(self, batch):
-        """GAROM solver training step.
+        """
+        The optimization cycle for the GAROM solver.
 
-        :param batch: The batch element in the dataloader.
-        :type batch: tuple
-        :return: The sum of the loss functions.
+        :param tuple batch: The batch element in the dataloader.
+        :return: The loss of the optimization cycle.
         :rtype: LabelTensor
         """
         condition_loss = {}
@@ -258,6 +262,13 @@ class GAROM(MultiSolverInterface):
         return condition_loss
 
     def validation_step(self, batch):
+        """
+        The validation step for the GAROM solver.
+
+        :param dict batch: The batch of data to use in the validation step.
+        :return: The loss of the validation step.
+        :rtype: torch.Tensor
+        """
         condition_loss = {}
         for condition_name, points in batch:
             parameters, snapshots = (
@@ -273,6 +284,13 @@ class GAROM(MultiSolverInterface):
         return loss
 
    def test_step(self, batch):
+        """
+        The test step for the GAROM solver.
+
+        :param dict batch: The batch of data to use in the test step.
+        :return: The loss of the test step.
+        :rtype: torch.Tensor
+        """
         condition_loss = {}
         for condition_name, points in batch:
             parameters, snapshots = (
@@ -289,30 +307,60 @@ class GAROM(MultiSolverInterface):
     @property
     def generator(self):
-        """TODO"""
+        """
+        The generator model.
+
+        :return: The generator model.
+        :rtype: torch.nn.Module
+        """
         return self.models[0]
 
     @property
     def discriminator(self):
-        """TODO"""
+        """
+        The discriminator model.
+
+        :return: The discriminator model.
+        :rtype: torch.nn.Module
+        """
         return self.models[1]
 
     @property
     def optimizer_generator(self):
-        """TODO"""
+        """
+        The optimizer for the generator.
+
+        :return: The optimizer for the generator.
+        :rtype: Optimizer
+        """
         return self.optimizers[0].instance
 
     @property
     def optimizer_discriminator(self):
-        """TODO"""
+        """
+        The optimizer for the discriminator.
+
+        :return: The optimizer for the discriminator.
+        :rtype: Optimizer
+        """
         return self.optimizers[1].instance
 
     @property
     def scheduler_generator(self):
-        """TODO"""
+        """
+        The scheduler for the generator.
+
+        :return: The scheduler for the generator.
+        :rtype: Scheduler
+        """
         return self.schedulers[0].instance
 
     @property
     def scheduler_discriminator(self):
-        """TODO"""
+        """
+        The scheduler for the discriminator.
+
+        :return: The scheduler for the discriminator.
+        :rtype: Scheduler
+        """
         return self.schedulers[1].instance

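For orientation, a minimal usage sketch of the API documented above. The import path, the ``problem`` object, and all layer sizes are assumptions for illustration; only the ``GAROM(...)`` keyword arguments and the ``forward`` semantics come from this commit:

import torch
from pina.solver import GAROM  # assumed import path

# toy generator/discriminator; GAROM only requires two torch.nn.Module models
generator = torch.nn.Sequential(
    torch.nn.Linear(2, 64), torch.nn.ReLU(), torch.nn.Linear(64, 100)
)
discriminator = torch.nn.Sequential(
    torch.nn.Linear(102, 64), torch.nn.ReLU(), torch.nn.Linear(64, 1)
)

# problem: an existing data-driven problem with input/target conditions,
# defined elsewhere (not constructed here)
solver = GAROM(problem=problem, generator=generator, discriminator=discriminator)

x = torch.rand(16, 2)                              # batch of input parameters
mean = solver(x)                                   # expected value, 20 MC samples
mean, var = solver(x, mc_steps=50, variance=True)  # also return the sample variance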
View File

@@ -110,8 +110,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta):
     def loss_data(self, input_pts, output_pts):
         """
         Compute the data loss for the PINN solver by evaluating the loss
-        between the network's output and the true solution. This method
-        should only be overridden intentionally.
+        between the network's output and the true solution. This method should
+        not be overridden unless intentionally.
 
         :param LabelTensor input_pts: The input points to the neural network.
         :param LabelTensor output_pts: The true solution to compare with the

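Conceptually, the documented contract reduces to a one-liner; this sketch mirrors the supervised implementation shown later in this commit, assuming the solver exposes its loss as ``self._loss`` and its prediction as ``self.forward``:

def loss_data(self, input_pts, output_pts):
    # compare the network's prediction on the input points with the true solution
    return self._loss(self.forward(input_pts), output_pts)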
View File

@@ -1,19 +1,17 @@
 """Module for ReducedOrderModelSolver"""
 
 import torch
 from .supervised import SupervisedSolver
 
 
 class ReducedOrderModelSolver(SupervisedSolver):
     r"""
-    ReducedOrderModelSolver solver class. This class implements a
-    Reduced Order Model solver, using user specified ``reduction_network`` and
-    ``interpolation_network`` to solve a specific ``problem``.
+    Reduced Order Model solver class. This class implements the Reduced Order
+    Model solver, using user specified ``reduction_network`` and
+    ``interpolation_network`` to solve a specific ``problem``.
 
-    The Reduced Order Model approach aims to find
-    the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`
-    of the differential problem:
+    The Reduced Order Model solver aims to find the solution
+    :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential problem:
 
     .. math::
@@ -23,13 +21,13 @@ class ReducedOrderModelSolver(SupervisedSolver):
             \mathbf{x}\in\partial\Omega
         \end{cases}
 
-    This is done by using two neural networks. The ``reduction_network``, which
-    contains an encoder :math:`\mathcal{E}_{\rm{net}}`, a decoder
-    :math:`\mathcal{D}_{\rm{net}}`; and an ``interpolation_network``
+    This is done by means of two neural networks: the ``reduction_network``,
+    which defines an encoder :math:`\mathcal{E}_{\rm{net}}`, and a decoder
+    :math:`\mathcal{D}_{\rm{net}}`; and the ``interpolation_network``
     :math:`\mathcal{I}_{\rm{net}}`. The input is assumed to be discretised in
     the spatial dimensions.
 
-    The following loss function is minimized during training
+    The following loss function is minimized during training:
 
     .. math::
         \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
@@ -39,49 +37,46 @@ class ReducedOrderModelSolver(SupervisedSolver):
             \mathcal{D}_{\rm{net}}[\mathcal{E}_{\rm{net}}[\mathbf{u}(\mu_i)]] -
             \mathbf{u}(\mu_i))
 
-    where :math:`\mathcal{L}` is a specific loss function, default
-    Mean Square Error:
+    where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
 
     .. math::
         \mathcal{L}(v) = \| v \|^2_2.
 
     .. seealso::
         **Original reference**: Hesthaven, Jan S., and Stefano Ubbiali.
-        "Non-intrusive reduced order modeling of nonlinear problems
-        using neural networks." Journal of Computational
-        Physics 363 (2018): 55-78.
+        "Non-intrusive reduced order modeling of nonlinear problems using
+        neural networks."
+        Journal of Computational Physics 363 (2018): 55-78.
         DOI `10.1016/j.jcp.2018.02.037
         <https://doi.org/10.1016/j.jcp.2018.02.037>`_.
 
     .. note::
-        The specified ``reduction_network`` must contain two methods,
-        namely ``encode`` for input encoding and ``decode`` for decoding the
-        former result. The ``interpolation_network`` network ``forward`` output
-        represents the interpolation of the latent space obtain with
+        The specified ``reduction_network`` must contain two methods, namely
+        ``encode`` for input encoding, and ``decode`` for decoding the former
+        result. The ``interpolation_network`` network ``forward`` output
+        represents the interpolation of the latent space obtained with
         ``reduction_network.encode``.
 
     .. note::
         This solver uses the end-to-end training strategy, i.e. the
         ``reduction_network`` and ``interpolation_network`` are trained
-        simultaneously. For reference on this trainig strategy look at:
-        Pichi, Federico, Beatriz Moya, and Jan S. Hesthaven.
-        "A graph convolutional autoencoder approach to model order reduction
-        for parametrized PDEs." Journal of
-        Computational Physics 501 (2024): 112762.
-        DOI
-        `10.1016/j.jcp.2024.112762 <https://doi.org/10.1016/
-        j.jcp.2024.112762>`_.
+        simultaneously. For reference on this training strategy look at the
+        following:
+
+        .. seealso::
+            **Original reference**: Pichi, Federico, Beatriz Moya, and Jan S.
+            Hesthaven.
+            "A graph convolutional autoencoder approach to model order
+            reduction for parametrized PDEs."
+            Journal of Computational Physics 501 (2024): 112762.
+            DOI `10.1016/j.jcp.2024.112762
+            <https://doi.org/10.1016/j.jcp.2024.112762>`_.
 
     .. warning::
         This solver works only for data-driven model. Hence in the ``problem``
         definition the codition must only contain ``input``
         (e.g. coefficient parameters, time parameters), and ``target``.
-
-    .. warning::
-        This solver does not currently support the possibility to pass
-        ``extra_feature``.
     """
 
     def __init__(
@@ -96,22 +91,28 @@ class ReducedOrderModelSolver(SupervisedSolver):
         use_lt=True,
     ):
         """
+        Initialization of the :class:`ReducedOrderModelSolver` class.
+
         :param AbstractProblem problem: The formualation of the problem.
         :param torch.nn.Module reduction_network: The reduction network used
-            for reducing the input space. It must contain two methods,
-            namely ``encode`` for input encoding and ``decode`` for decoding the
+            for reducing the input space. It must contain two methods, namely
+            ``encode`` for input encoding, and ``decode`` for decoding the
             former result.
         :param torch.nn.Module interpolation_network: The interpolation network
-            for interpolating the control parameters to latent space obtain by
+            for interpolating the control parameters to latent space obtained
             the ``reduction_network`` encoding.
-        :param torch.nn.Module loss: The loss function used as minimizer,
-            default :class:`torch.nn.MSELoss`.
-        :param torch.optim.Optimizer optimizer: The neural network optimizer to
-            use; default is :class:`torch.optim.Adam`.
-        :param torch.optim.LRScheduler scheduler: Learning
-            rate scheduler.
-        :param WeightingInterface weighting: The loss weighting to use.
-        :param bool use_lt: Using LabelTensors as input during training.
+        :param torch.nn.Module loss: The loss function to be minimized.
+            If ``None``, the :class:`torch.nn.MSELoss` loss is used.
+            Default is ``None``.
+        :param Optimizer optimizer: The optimizer to be used.
+            If ``None``, the :class:`torch.optim.Adam` optimizer is used.
+            Default is ``None``.
+        :param Scheduler scheduler: Learning rate scheduler. If ``None``,
+            the constant learning rate scheduler is used. Default is ``None``.
+        :param WeightingInterface weighting: The weighting schema to be used.
+            If ``None``, no weighting schema is used. Default is ``None``.
+        :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
+            Default is ``True``.
         """
         model = torch.nn.ModuleDict(
             {
@@ -146,10 +147,10 @@ class ReducedOrderModelSolver(SupervisedSolver):
     def forward(self, x):
         """
-        Forward pass implementation for the solver. It finds the encoder
-        representation by calling ``interpolation_network.forward`` on the
-        input, and maps this representation to output space by calling
-        ``reduction_network.decode``.
+        Forward pass implementation.
+        It computes the encoder representation by calling the forward method
+        of the ``interpolation_network`` on the input, and maps it to output
+        space by calling the decode method of the ``reduction_network``.
 
         :param torch.Tensor x: Input tensor.
         :return: Solver solution.
@@ -161,15 +162,14 @@ class ReducedOrderModelSolver(SupervisedSolver):
     def loss_data(self, input_pts, output_pts):
         """
-        The data loss for the ReducedOrderModelSolver solver.
-        It computes the loss between
-        the network output against the true solution. This function
-        should not be override if not intentionally.
+        Compute the data loss by evaluating the loss between the network's
+        output and the true solution. This method should not be overridden
+        unless intentionally.
 
-        :param LabelTensor input_tensor: The input to the neural networks.
-        :param LabelTensor output_tensor: The true solution to compare the
-            network solution.
-        :return: The residual loss averaged on the input coordinates
+        :param LabelTensor input_pts: The input points to the neural network.
+        :param LabelTensor output_pts: The true solution to compare with the
+            network's output.
+        :return: The supervised loss, averaged over the number of observations.
         :rtype: torch.Tensor
         """
         # extract networks

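The two notes in the class docstring pin down the required interfaces. A minimal sketch of conforming networks follows; the import path, the ``problem`` object, and all dimensions are assumptions, while the ``encode``/``decode`` contract and the end-to-end training come from the docstring itself:

import torch
from pina.solver import ReducedOrderModelSolver  # assumed import path

class ReductionNetwork(torch.nn.Module):
    """Toy autoencoder exposing the required encode/decode pair."""
    def __init__(self, field_dim=100, latent_dim=8):
        super().__init__()
        self.encoder = torch.nn.Linear(field_dim, latent_dim)
        self.decoder = torch.nn.Linear(latent_dim, field_dim)

    def encode(self, u):
        # snapshot -> latent coordinates
        return self.encoder(u)

    def decode(self, z):
        # latent coordinates -> snapshot
        return self.decoder(z)

class InterpolationNetwork(torch.nn.Module):
    """Maps control parameters mu to the latent space of the autoencoder."""
    def __init__(self, param_dim=2, latent_dim=8):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(param_dim, 32),
            torch.nn.Tanh(),
            torch.nn.Linear(32, latent_dim),
        )

    def forward(self, mu):
        return self.net(mu)

# both networks are trained end-to-end, as described in the second note;
# problem is a data-driven problem with input/target conditions, defined elsewhere
solver = ReducedOrderModelSolver(
    problem=problem,
    reduction_network=ReductionNetwork(),
    interpolation_network=InterpolationNetwork(),
)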
View File

@@ -14,17 +14,19 @@ from ..utils import check_consistency, labelize_forward
 class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
     """
-    SolverInterface base class. This class is a wrapper of LightningModule.
+    Abstract base class for PINA solvers. All specific solvers should inherit
+    from this interface. This class is a wrapper of
+    :class:`~lightning.pytorch.LightningModule`.
     """
 
     def __init__(self, problem, weighting, use_lt):
         """
-        :param problem: A problem definition instance.
-        :type problem: AbstractProblem
-        :param weighting: The loss weighting to use.
-        :type weighting: WeightingInterface
-        :param use_lt: Using LabelTensors as input during training.
-        :type use_lt: bool
+        Initialization of the :class:`SolverInterface` class.
+
+        :param AbstractProblem problem: The problem to be solved.
+        :param WeightingInterface weighting: The weighting schema to be used.
+            If ``None``, no weighting schema is used. Default is ``None``.
+        :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
         """
         super().__init__()
@@ -59,22 +61,24 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         self._pina_schedulers = None
 
     def _check_solver_consistency(self, problem):
+        """
+        Check the consistency of the solver with the problem formulation.
+
+        :param AbstractProblem problem: The problem to be solved.
+        """
         for condition in problem.conditions.values():
             check_consistency(condition, self.accepted_conditions_types)
 
     def _optimization_cycle(self, batch):
         """
-        Perform a private optimization cycle by computing the loss for each
-        condition in the given batch. The loss are later aggregated using the
-        specific weighting schema.
+        Aggregate the loss for each condition in the batch.
 
-        :param batch: A batch of data, where each element is a tuple containing
-            a condition name and a dictionary of points.
-        :type batch: list of tuples (str, dict)
-        :return: The computed loss for the all conditions in the batch,
-            cast to a subclass of `torch.Tensor`. It should return a dict
-            containing the condition name and the associated scalar loss.
-        :rtype: dict(torch.Tensor)
+        :param list[tuple[str, dict]] batch: A batch of data. Each element is a
+            tuple containing a condition name and a dictionary of points.
+        :return: The computed loss for all the conditions in the batch, cast
+            to a subclass of ``torch.Tensor``. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict
         """
         losses = self.optimization_cycle(batch)
         for name, value in losses.items():
@@ -88,9 +92,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         """
         Solver training step.
 
-        :param batch: The batch element in the dataloader.
-        :type batch: tuple
-        :return: The sum of the loss functions.
+        :param list[tuple[str, dict]] batch: The batch element in the dataloader.
+        :return: The loss of the training step.
         :rtype: LabelTensor
         """
         loss = self._optimization_cycle(batch=batch)
@@ -101,8 +104,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         """
         Solver validation step.
 
-        :param batch: The batch element in the dataloader.
-        :type batch: tuple
+        :param list[tuple[str, dict]] batch: The batch element in the dataloader.
         """
         loss = self._optimization_cycle(batch=batch)
         self.store_log("val_loss", loss, self.get_batch_size(batch))
@@ -111,15 +113,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
         """
         Solver test step.
 
-        :param batch: The batch element in the dataloader.
-        :type batch: tuple
+        :param list[tuple[str, dict]] batch: The batch element in the dataloader.
         """
         loss = self._optimization_cycle(batch=batch)
         self.store_log("test_loss", loss, self.get_batch_size(batch))
 
     def store_log(self, name, value, batch_size):
         """
-        TODO
+        Store the log of the solver.
+
+        :param str name: The name of the log.
+        :param torch.Tensor value: The value of the log.
+        :param int batch_size: The size of the batch.
         """
         self.log(
@@ -132,49 +137,59 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
     @abstractmethod
     def forward(self, *args, **kwargs):
         """
-        TODO
+        Abstract method for the forward pass implementation.
         """
 
     @abstractmethod
     def optimization_cycle(self, batch):
         """
-        Perform an optimization cycle by computing the loss for each condition
-        in the given batch.
+        The optimization cycle for the solvers.
 
-        :param batch: A batch of data, where each element is a tuple containing
-            a condition name and a dictionary of points.
-        :type batch: list of tuples (str, dict)
-        :return: The computed loss for the all conditions in the batch,
-            cast to a subclass of `torch.Tensor`. It should return a dict
-            containing the condition name and the associated scalar loss.
-        :rtype: dict(torch.Tensor)
+        :param list[tuple[str, dict]] batch: The batch element in the dataloader.
+        :return: The computed loss for all the conditions in the batch, cast
+            to a subclass of ``torch.Tensor``. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict
         """
 
     @property
     def problem(self):
         """
-        The problem formulation.
+        The problem instance.
+
+        :return: The problem instance.
+        :rtype: :class:`~pina.problem.abstract_problem.AbstractProblem`
         """
         return self._pina_problem
 
     @property
     def use_lt(self):
         """
-        Using LabelTensor in training.
+        Using LabelTensors as input during training.
+
+        :return: The use_lt attribute.
+        :rtype: bool
         """
         return self._use_lt
 
     @property
     def weighting(self):
         """
-        The weighting mechanism.
+        The weighting schema.
+
+        :return: The weighting schema.
+        :rtype: :class:`~pina.loss.weighting_interface.WeightingInterface`
         """
         return self._pina_weighting
 
     @staticmethod
     def get_batch_size(batch):
         """
-        TODO
+        Get the batch size.
+
+        :param list[tuple[str, dict]] batch: The batch element in the dataloader.
+        :return: The size of the batch.
+        :rtype: int
         """
         batch_size = 0
@@ -185,23 +200,29 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
     @staticmethod
     def default_torch_optimizer():
         """
-        TODO
+        Set the default optimizer to :class:`torch.optim.Adam`.
+
+        :return: The default optimizer.
+        :rtype: Optimizer
         """
         return TorchOptimizer(torch.optim.Adam, lr=0.001)
 
     @staticmethod
     def default_torch_scheduler():
         """
-        TODO
+        Set the default scheduler to
+        :class:`torch.optim.lr_scheduler.ConstantLR`.
+
+        :return: The default scheduler.
+        :rtype: Scheduler
         """
         return TorchScheduler(torch.optim.lr_scheduler.ConstantLR)
 
     def on_train_start(self):
         """
-        Hook that is called before training begins.
-        Used to compile the model if the trainer is set to compile.
+        This method is called at the start of the training process to compile
+        the model if the :class:`~pina.trainer.Trainer` ``compile`` flag is
+        ``True``.
         """
         super().on_train_start()
         if self.trainer.compile:
@@ -209,8 +230,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
     def on_test_start(self):
         """
-        Hook that is called before training begins.
-        Used to compile the model if the trainer is set to compile.
+        This method is called at the start of the test process to compile
+        the model if the :class:`~pina.trainer.Trainer` ``compile`` flag is
+        ``True``.
         """
         super().on_train_start()
         if self.trainer.compile and not self._check_already_compiled():
@@ -218,7 +239,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
     def _check_already_compiled(self):
         """
-        TODO
+        Check if the model is already compiled.
+
+        :return: ``True`` if the model is already compiled, ``False`` otherwise.
+        :rtype: bool
         """
         models = self._pina_models
@@ -234,7 +258,12 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
     @staticmethod
     def _perform_compilation(model):
         """
-        TODO
+        Perform the compilation of the model.
+
+        :param torch.nn.Module model: The model to compile.
+        :raises Exception: If the compilation fails.
+        :return: The compiled model.
+        :rtype: torch.nn.Module
         """
         model_device = next(model.parameters()).device
@@ -249,8 +278,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta):
 class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
-    """TODO"""
+    """
+    Base class for PINA solvers using a single :class:`torch.nn.Module`.
+    """
 
     def __init__(
         self,
         problem,
@@ -261,14 +291,18 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
         use_lt=True,
     ):
         """
-        :param problem: A problem definition instance.
-        :type problem: AbstractProblem
-        :param model: A torch nn.Module instances.
-        :type model: torch.nn.Module
-        :param Optimizer optimizers: A neural network optimizers to use.
-        :param Scheduler optimizers: A neural network scheduler to use.
-        :param WeightingInterface weighting: The loss weighting to use.
-        :param bool use_lt: Using LabelTensors as input during training.
+        Initialization of the :class:`SingleSolverInterface` class.
+
+        :param AbstractProblem problem: The problem to be solved.
+        :param torch.nn.Module model: The neural network model to be used.
+        :param Optimizer optimizer: The optimizer to be used.
+            If ``None``, the Adam optimizer is used. Default is ``None``.
+        :param Scheduler scheduler: The scheduler to be used.
+            If ``None``, the constant learning rate scheduler is used.
+            Default is ``None``.
+        :param WeightingInterface weighting: The weighting schema to be used.
+            If ``None``, no weighting schema is used. Default is ``None``.
+        :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
         """
         if optimizer is None:
             optimizer = self.default_torch_optimizer()
@@ -292,11 +326,12 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
     def forward(self, x):
         """
-        Forward pass implementation for the solver.
+        Forward pass implementation.
 
-        :param torch.Tensor x: Input tensor.
+        :param x: Input tensor.
+        :type x: torch.Tensor | LabelTensor
         :return: Solver solution.
-        :rtype: torch.Tensor
+        :rtype: torch.Tensor | LabelTensor
         """
         x = self.model(x)
         return x
@@ -305,7 +340,7 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
         """
         Optimizer configuration for the solver.
 
-        :return: The optimizers and the schedulers
+        :return: The optimizer and the scheduler
         :rtype: tuple(list, list)
         """
         self.optimizer.hook(self.model.parameters())
@@ -313,44 +348,61 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta):
         return ([self.optimizer.instance], [self.scheduler.instance])
 
     def _compile_model(self):
+        """
+        Compile the model.
+        """
         if isinstance(self._pina_models[0], torch.nn.ModuleDict):
             self._compile_module_dict()
         else:
             self._compile_single_model()
 
     def _compile_module_dict(self):
+        """
+        Compile the model if it is a :class:`torch.nn.ModuleDict`.
+        """
         for name, model in self._pina_models[0].items():
             self._pina_models[0][name] = self._perform_compilation(model)
 
     def _compile_single_model(self):
+        """
+        Compile the model if it is a single :class:`torch.nn.Module`.
+        """
         self._pina_models[0] = self._perform_compilation(self._pina_models[0])
 
     @property
     def model(self):
         """
-        Model for training.
+        The model used for training.
+
+        :return: The model used for training.
+        :rtype: torch.nn.Module
        """
         return self._pina_models[0]
 
     @property
     def scheduler(self):
         """
-        Scheduler for training.
+        The scheduler used for training.
+
+        :return: The scheduler used for training.
+        :rtype: Scheduler
         """
         return self._pina_schedulers[0]
 
     @property
     def optimizer(self):
         """
-        Optimizer for training.
+        The optimizer used for training.
+
+        :return: The optimizer used for training.
+        :rtype: Optimizer
         """
         return self._pina_optimizers[0]
 
 
 class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
     """
-    Multiple Solver base class. This class inherits is a wrapper of
-    SolverInterface class
+    Base class for PINA solvers using multiple :class:`torch.nn.Module`.
     """
 
     def __init__(
@@ -363,16 +415,22 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
         use_lt=True,
     ):
         """
-        :param problem: A problem definition instance.
-        :type problem: AbstractProblem
-        :param models: Multiple torch nn.Module instances.
+        Initialization of the :class:`MultiSolverInterface` class.
+
+        :param AbstractProblem problem: The problem to be solved.
+        :param models: The neural network models to be used.
         :type model: list[torch.nn.Module] | tuple[torch.nn.Module]
-        :param list(Optimizer) optimizers: A list of neural network
-            optimizers to use.
-        :param list(Scheduler) optimizers: A list of neural network
-            schedulers to use.
-        :param WeightingInterface weighting: The loss weighting to use.
-        :param bool use_lt: Using LabelTensors as input during training.
+        :param list[Optimizer] optimizers: The optimizers to be used.
+            If ``None``, the Adam optimizer is used for all models.
+            Default is ``None``.
+        :param list[Scheduler] schedulers: The schedulers to be used.
+            If ``None``, the constant learning rate scheduler is used for all
+            the models. Default is ``None``.
+        :param WeightingInterface weighting: The weighting schema to be used.
+            If ``None``, no weighting schema is used. Default is ``None``.
+        :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
+        :raises ValueError: If the models are not a list or tuple with length
+            greater than one.
         """
         if not isinstance(models, (list, tuple)) or len(models) < 2:
             raise ValueError(
@@ -418,9 +476,10 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
         self._pina_schedulers = schedulers
 
     def configure_optimizers(self):
-        """Optimizer configuration for the solver.
+        """
+        Optimizer configuration for the solver.
 
-        :return: The optimizers and the schedulers
+        :return: The optimizers and the schedulers
         :rtype: tuple(list, list)
         """
         for optimizer, scheduler, model in zip(
@@ -435,6 +494,9 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
         )
 
     def _compile_model(self):
+        """
+        Compile the model.
+        """
         for i, model in enumerate(self._pina_models):
             if not isinstance(model, torch.nn.ModuleDict):
                 self._pina_models[i] = self._perform_compilation(model)
@@ -442,17 +504,29 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta):
     @property
     def models(self):
         """
-        The torch model."""
+        The models used for training.
+
+        :return: The models used for training.
+        :rtype: torch.nn.ModuleList
+        """
         return self._pina_models
 
     @property
     def optimizers(self):
         """
-        The torch model."""
+        The optimizers used for training.
+
+        :return: The optimizers used for training.
+        :rtype: list[Optimizer]
+        """
         return self._pina_optimizers
 
     @property
     def schedulers(self):
         """
-        The torch model."""
+        The schedulers used for training.
+
+        :return: The schedulers used for training.
+        :rtype: list[Scheduler]
+        """
         return self._pina_schedulers

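To make the interface contract concrete: a sketch of a minimal concrete subclass of ``SingleSolverInterface``. The import paths and the ``input``/``target`` batch keys are assumptions (mirroring the data-driven solvers elsewhere in this commit); only the two abstract methods and the dict-of-scalar-losses return shape come from the docstrings above:

import torch
from pina.solver import SingleSolverInterface  # assumed import path
from pina.condition import InputTargetCondition  # assumed import path

class MySolver(SingleSolverInterface):
    accepted_conditions_types = InputTargetCondition

    def forward(self, x):
        # delegate to the wrapped model, as SingleSolverInterface.forward does
        return self.model(x)

    def optimization_cycle(self, batch):
        # one scalar loss per condition, keyed by condition name
        return {
            name: torch.nn.functional.mse_loss(
                self.forward(points["input"]), points["target"]
            )
            for name, points in batch
        }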
View File

@@ -1,4 +1,4 @@
-"""Module for SupervisedSolver"""
+"""Module for the Supervised Solver."""
 
 import torch
 from torch.nn.modules.loss import _Loss
@@ -10,31 +10,28 @@ from ..condition import InputTargetCondition
 class SupervisedSolver(SingleSolverInterface):
     r"""
-    SupervisedSolver solver class. This class implements a SupervisedSolver,
+    Supervised Solver class. This class implements a Supervised Solver,
     using a user specified ``model`` to solve a specific ``problem``.
 
-    The Supervised Solver class aims to find
-    a map between the input :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m`
-    and the output :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. The input
-    can be discretised in space (as in :obj:`~pina.solver.rom.ROMe2eSolver`),
-    or not (e.g. when training Neural Operators).
+    The Supervised Solver class aims to find a map between the input
+    :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m` and the output
+    :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`.
 
     Given a model :math:`\mathcal{M}`, the following loss function is
     minimized during training:
 
     .. math::
         \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
-        \mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i))
+        \mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i)),
 
-    where :math:`\mathcal{L}` is a specific loss function,
-    default Mean Square Error:
+    where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
 
     .. math::
         \mathcal{L}(v) = \| v \|^2_2.
 
-    In this context :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` means that
-    we are seeking to approximate multiple (discretised) functions given
-    multiple (discretised) input functions.
+    In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` indicate
+    that the aim is to approximate multiple (discretised) functions given
+    multiple (discretised) input functions.
     """
 
     accepted_conditions_types = InputTargetCondition
@@ -50,16 +47,22 @@ class SupervisedSolver(SingleSolverInterface):
         use_lt=True,
     ):
         """
-        :param AbstractProblem problem: The formualation of the problem.
-        :param torch.nn.Module model: The neural network model to use.
-        :param torch.nn.Module loss: The loss function used as minimizer,
-            default :class:`torch.nn.MSELoss`.
-        :param torch.optim.Optimizer optimizer: The neural network optimizer to
-            use; default is :class:`torch.optim.Adam`.
-        :param torch.optim.LRScheduler scheduler: Learning
-            rate scheduler.
-        :param WeightingInterface weighting: The loss weighting to use.
-        :param bool use_lt: Using LabelTensors as input during training.
+        Initialization of the :class:`SupervisedSolver` class.
+
+        :param AbstractProblem problem: The problem to be solved.
+        :param torch.nn.Module model: The neural network model to be used.
+        :param torch.nn.Module loss: The loss function to be minimized.
+            If ``None``, the Mean Squared Error (MSE) loss is used.
+            Default is ``None``.
+        :param torch.optim.Optimizer optimizer: The optimizer to be used.
+            If ``None``, the Adam optimizer is used. Default is ``None``.
+        :param torch.optim.LRScheduler scheduler: Learning rate scheduler.
+            If ``None``, the constant learning rate scheduler is used.
+            Default is ``None``.
+        :param WeightingInterface weighting: The weighting schema to be used.
+            If ``None``, no weighting schema is used. Default is ``None``.
+        :param bool use_lt: If ``True``, the solver uses LabelTensors as input.
+            Default is ``True``.
         """
         if loss is None:
             loss = torch.nn.MSELoss()
@@ -81,16 +84,13 @@ class SupervisedSolver(SingleSolverInterface):
     def optimization_cycle(self, batch):
         """
-        Perform an optimization cycle by computing the loss for each condition
-        in the given batch.
+        The optimization cycle for the solver.
 
-        :param batch: A batch of data, where each element is a tuple containing
-            a condition name and a dictionary of points.
-        :type batch: list of tuples (str, dict)
-        :return: The computed loss for the all conditions in the batch,
-            cast to a subclass of `torch.Tensor`. It should return a dict
-            containing the condition name and the associated scalar loss.
-        :rtype: dict(torch.Tensor)
+        :param list[tuple[str, dict]] batch: The batch element in the dataloader.
+        :return: The computed loss for all the conditions in the batch, cast
+            to a subclass of ``torch.Tensor``. It should return a dict
+            containing the condition name and the associated scalar loss.
+        :rtype: dict
         """
         condition_loss = {}
         for condition_name, points in batch:
@@ -105,16 +105,16 @@ class SupervisedSolver(SingleSolverInterface):
     def loss_data(self, input_pts, output_pts):
         """
-        The data loss for the Supervised solver. It computes the loss between
-        the network output against the true solution. This function
-        should not be override if not intentionally.
+        Compute the data loss for the Supervised solver by evaluating the loss
+        between the network's output and the true solution. This method should
+        not be overridden unless intentionally.
 
-        :param input_pts: The input to the neural networks.
+        :param input_pts: The input points to the neural network.
         :type input_pts: LabelTensor | torch.Tensor
-        :param output_pts: The true solution to compare the
-            network solution.
+        :param output_pts: The true solution to compare with the network's
+            output.
         :type output_pts: LabelTensor | torch.Tensor
-        :return: The residual loss.
+        :return: The supervised loss, averaged over the number of observations.
         :rtype: torch.Tensor
         """
         return self._loss(self.forward(input_pts), output_pts)
@@ -122,6 +122,9 @@ class SupervisedSolver(SingleSolverInterface):
     @property
     def loss(self):
         """
-        Loss for training.
+        The loss function to be minimized.
+
+        :return: The loss function to be minimized.
+        :rtype: torch.nn.Module
         """
         return self._loss
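
Putting the documented defaults together, a minimal usage sketch. The import path and the ``problem`` object are assumptions; the default loss, optimizer, and scheduler are the ones stated in the docstrings above:

import torch
from pina.solver import SupervisedSolver  # assumed import path

model = torch.nn.Sequential(
    torch.nn.Linear(2, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1)
)

# loss=None -> torch.nn.MSELoss, optimizer=None -> Adam, scheduler=None -> ConstantLR;
# problem is a data-driven problem with input/target conditions, defined elsewhere
solver = SupervisedSolver(problem=problem, model=model)

# per the docstring, the data loss reduces to:
#   solver.loss(solver.forward(input_pts), output_pts)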