fix doc solver

This commit is contained in:
giovanni
2025-03-13 19:05:15 +01:00
committed by FilippoOlivo
parent fd2b50fc06
commit 5e6aa61592
6 changed files with 374 additions and 251 deletions

View File

@@ -1,4 +1,4 @@
"""Module for SupervisedSolver"""
"""Module for the Supervised Solver."""
import torch
from torch.nn.modules.loss import _Loss
@@ -10,31 +10,28 @@ from ..condition import InputTargetCondition
class SupervisedSolver(SingleSolverInterface):
r"""
SupervisedSolver solver class. This class implements a SupervisedSolver,
Supervised Solver solver class. This class implements a Supervised Solver,
using a user specified ``model`` to solve a specific ``problem``.
The Supervised Solver class aims to find
a map between the input :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m`
and the output :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. The input
can be discretised in space (as in :obj:`~pina.solver.rom.ROMe2eSolver`),
or not (e.g. when training Neural Operators).
The Supervised Solver class aims to find a map between the input
:math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m` and the output
:math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`.
Given a model :math:`\mathcal{M}`, the following loss function is
minimized during training:
.. math::
\mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N
\mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i))
\mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i)),
where :math:`\mathcal{L}` is a specific loss function,
default Mean Square Error:
where :math:`\mathcal{L}` is a specific loss function, typically the MSE:
.. math::
\mathcal{L}(v) = \| v \|^2_2.
In this context :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` means that
we are seeking to approximate multiple (discretised) functions given
multiple (discretised) input functions.
In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` indicates
the will to approximate multiple (discretised) functions given multiple
(discretised) input functions.
"""
accepted_conditions_types = InputTargetCondition
@@ -50,16 +47,22 @@ class SupervisedSolver(SingleSolverInterface):
use_lt=True,
):
"""
:param AbstractProblem problem: The formualation of the problem.
:param torch.nn.Module model: The neural network model to use.
:param torch.nn.Module loss: The loss function used as minimizer,
default :class:`torch.nn.MSELoss`.
:param torch.optim.Optimizer optimizer: The neural network optimizer to
use; default is :class:`torch.optim.Adam`.
:param torch.optim.LRScheduler scheduler: Learning
rate scheduler.
:param WeightingInterface weighting: The loss weighting to use.
:param bool use_lt: Using LabelTensors as input during training.
Initialization of the :class:`SupervisedSolver` class.
:param AbstractProblem problem: The problem to be solved.
:param torch.nn.Module model: The neural network model to be used.
:param torch.nn.Module loss: The loss function to be minimized.
If `None`, the Mean Squared Error (MSE) loss is used.
Default is `None`.
:param torch.optim.Optimizer optimizer: The optimizer to be used.
If `None`, the Adam optimizer is used. Default is ``None``.
:param torch.optim.LRScheduler scheduler: Learning rate scheduler.
If `None`, the constant learning rate scheduler is used.
Default is ``None``.
:param WeightingInterface weighting: The weighting schema to be used.
If `None`, no weighting schema is used. Default is ``None``.
:param bool use_lt: If ``True``, the solver uses LabelTensors as input.
Default is ``True``.
"""
if loss is None:
loss = torch.nn.MSELoss()
@@ -81,16 +84,13 @@ class SupervisedSolver(SingleSolverInterface):
def optimization_cycle(self, batch):
"""
Perform an optimization cycle by computing the loss for each condition
in the given batch.
The optimization cycle for the solvers.
:param batch: A batch of data, where each element is a tuple containing
a condition name and a dictionary of points.
:type batch: list of tuples (str, dict)
:return: The computed loss for the all conditions in the batch,
cast to a subclass of `torch.Tensor`. It should return a dict
containing the condition name and the associated scalar loss.
:rtype: dict(torch.Tensor)
:param list[tuple[str, dict]] batch: The batch element in the dataloader.
:return: The computed loss for the all conditions in the batch, casted
to a subclass of `torch.Tensor`. It should return a dict containing
the condition name and the associated scalar loss.
:rtype: dict
"""
condition_loss = {}
for condition_name, points in batch:
@@ -105,16 +105,16 @@ class SupervisedSolver(SingleSolverInterface):
def loss_data(self, input_pts, output_pts):
"""
The data loss for the Supervised solver. It computes the loss between
the network output against the true solution. This function
should not be override if not intentionally.
Compute the data loss for the Supervised solver by evaluating the loss
between the network's output and the true solution. This method should
not be overridden, if not intentionally.
:param input_pts: The input to the neural networks.
:param input_pts: The input points to the neural network.
:type input_pts: LabelTensor | torch.Tensor
:param output_pts: The true solution to compare the
network solution.
:param output_pts: The true solution to compare with the network's
output.
:type output_pts: LabelTensor | torch.Tensor
:return: The residual loss.
:return: The supervised loss, averaged over the number of observations.
:rtype: torch.Tensor
"""
return self._loss(self.forward(input_pts), output_pts)
@@ -122,6 +122,9 @@ class SupervisedSolver(SingleSolverInterface):
@property
def loss(self):
"""
Loss for training.
The loss function to be minimized.
:return: The loss function to be minimized.
:rtype: torch.nn.Module
"""
return self._loss