diff --git a/pina/solver/__init__.py b/pina/solver/__init__.py index 7a10cf9..cdca62d 100644 --- a/pina/solver/__init__.py +++ b/pina/solver/__init__.py @@ -1,6 +1,4 @@ -""" -TODO -""" +"""Module for the solver classes.""" __all__ = [ "SolverInterface", diff --git a/pina/solver/garom.py b/pina/solver/garom.py index d023cf8..1edaab2 100644 --- a/pina/solver/garom.py +++ b/pina/solver/garom.py @@ -1,4 +1,4 @@ -"""Module for GAROM""" +"""Module for the GAROM solver.""" import torch from torch.nn.modules.loss import _Loss @@ -10,9 +10,9 @@ from ..loss import LossInterface, PowerLoss class GAROM(MultiSolverInterface): """ - GAROM solver class. This class implements Generative Adversarial - Reduced Order Model solver, using user specified ``models`` to solve - a specific order reduction``problem``. + GAROM solver class. This class implements Generative Adversarial Reduced + Order Model solver, using user specified ``models`` to solve a specific + order reduction ``problem``. .. seealso:: @@ -39,40 +39,28 @@ class GAROM(MultiSolverInterface): regularizer=False, ): """ - :param AbstractProblem problem: The formualation of the problem. - :param torch.nn.Module generator: The neural network model to use - for the generator. - :param torch.nn.Module discriminator: The neural network model to use + Initialization of the :class:`GAROM` class. + + :param AbstractProblem problem: The formulation of the problem. + :param torch.nn.Module generator: The generator model. + :param torch.nn.Module discriminator: The discriminator model. + :param torch.nn.Module loss: The loss function to be minimized. + If ``None``, ``PowerLoss(p=1)`` is used. Default is ``None``. + :param Optimizer optimizer_generator: The optimizer for the generator. + If `None`, the Adam optimizer is used. Default is ``None``. + :param Optimizer optimizer_discriminator: The optimizer for the + discriminator. If `None`, the Adam optimizer is used. + Default is ``None``. + :param Scheduler scheduler_generator: The learning rate scheduler for + the generator. + :param Scheduler scheduler_discriminator: The learning rate scheduler for the discriminator. - :param torch.nn.Module loss: The loss function used as minimizer, - default ``None``. If ``loss`` is ``None`` the defualt - ``PowerLoss(p=1)`` is used, as in the original paper. - :param Optimizer optimizer_generator: The neural - network optimizer to use for the generator network - , default is `torch.optim.Adam`. - :param Optimizer optimizer_discriminator: The neural - network optimizer to use for the discriminator network - , default is `torch.optim.Adam`. - :param Scheduler scheduler_generator: Learning - rate scheduler for the generator. - :param Scheduler scheduler_discriminator: Learning - rate scheduler for the discriminator. - :param dict scheduler_discriminator_kwargs: LR scheduler constructor - keyword args. - :param gamma: Ratio of expected loss for generator and discriminator, - defaults to 0.3. - :type gamma: float - :param lambda_k: Learning rate for control theory optimization, - defaults to 0.001. - :type lambda_k: float - :param regularizer: Regularization term in the GAROM loss, - defaults to False. - :type regularizer: bool - - .. warning:: - The algorithm works only for data-driven model. Hence in the - ``problem`` definition the codition must only contain ``input`` - (e.g. coefficient parameters, time parameters), and ``target``. + :param float gamma: Ratio of expected loss for generator and + discriminator. Default is ``0.3``. + :param float lambda_k: Learning rate for control theory optimization. + Default is ``0.001``. + :param bool regularizer: If ``True``, uses a regularization term in the + GAROM loss. Default is ``False``. """ # set loss @@ -112,19 +100,15 @@ class GAROM(MultiSolverInterface): def forward(self, x, mc_steps=20, variance=False): """ - Forward step for GAROM solver + Forward pass implementation. - :param x: The input tensor. - :type x: torch.Tensor - :param mc_steps: Number of montecarlo samples to approximate the - expected value, defaults to 20. - :type mc_steps: int - :param variance: Returining also the sample variance of the solution, - defaults to False. - :type variance: bool + :param torch.Tensor x: The input tensor. + :param int mc_steps: Number of Montecarlo samples to approximate the + expected value. Default is ``20``. + :param bool variance: If ``True``, the method returns also the variance + of the solution. Default is ``False``. :return: The expected value of the generator distribution. If - ``variance=True`` also the - sample variance is returned. + ``variance=True``, the method returns also the variance. :rtype: torch.Tensor | tuple(torch.Tensor, torch.Tensor) """ @@ -142,13 +126,24 @@ class GAROM(MultiSolverInterface): return mean def sample(self, x): - """TODO""" + """ + Sample from the generator distribution. + + :param torch.Tensor x: The input tensor. + :return: The generated sample. + :rtype: torch.Tensor + """ # sampling return self.generator(x) def _train_generator(self, parameters, snapshots): """ - Private method to train the generator network. + Train the generator model. + + :param torch.Tensor parameters: The input tensor. + :param torch.Tensor snapshots: The target tensor. + :return: The residual loss and the generator loss. + :rtype: tuple(torch.Tensor, torch.Tensor) """ optimizer = self.optimizer_generator optimizer.zero_grad() @@ -170,16 +165,13 @@ class GAROM(MultiSolverInterface): def on_train_batch_end(self, outputs, batch, batch_idx): """ - This method is called at the end of each training batch, and ovverides - the PytorchLightining implementation for logging the checkpoints. + This method is called at the end of each training batch and overrides + the PyTorch Lightning implementation to log checkpoints. - :param torch.Tensor outputs: The output from the model for the - current batch. - :param tuple batch: The current batch of data. + :param torch.Tensor outputs: The ``model``'s output for the current + batch. + :param dict batch: The current batch of data. :param int batch_idx: The index of the current batch. - :return: Whatever is returned by the parent - method ``on_train_batch_end``. - :rtype: Any """ # increase by one the counter of optimization to save loggers ( @@ -190,7 +182,12 @@ class GAROM(MultiSolverInterface): def _train_discriminator(self, parameters, snapshots): """ - Private method to train the discriminator network. + Train the discriminator model. + + :param torch.Tensor parameters: The input tensor. + :param torch.Tensor snapshots: The target tensor. + :return: The residual loss and the generator loss. + :rtype: tuple(torch.Tensor, torch.Tensor) """ optimizer = self.optimizer_discriminator optimizer.zero_grad() @@ -215,8 +212,15 @@ class GAROM(MultiSolverInterface): def _update_weights(self, d_loss_real, d_loss_fake): """ - Private method to Update the weights of the generator and discriminator - networks. + Update the weights of the generator and discriminator models. + + :param torch.Tensor d_loss_real: The discriminator loss computed on + dataset samples. + :param torch.Tensor d_loss_fake: The discriminator loss computed on + generated samples. + :return: The difference between the loss computed on the dataset samples + and the loss computed on the generated samples. + :rtype: torch.Tensor """ diff = torch.mean(self.gamma * d_loss_real - d_loss_fake) @@ -227,11 +231,11 @@ class GAROM(MultiSolverInterface): return diff def optimization_cycle(self, batch): - """GAROM solver training step. + """ + The optimization cycle for the GAROM solver. - :param batch: The batch element in the dataloader. - :type batch: tuple - :return: The sum of the loss functions. + :param tuple batch: The batch element in the dataloader. + :return: The loss of the optimization cycle. :rtype: LabelTensor """ condition_loss = {} @@ -258,6 +262,13 @@ class GAROM(MultiSolverInterface): return condition_loss def validation_step(self, batch): + """ + The validation step for the PINN solver. + + :param dict batch: The batch of data to use in the validation step. + :return: The loss of the validation step. + :rtype: torch.Tensor + """ condition_loss = {} for condition_name, points in batch: parameters, snapshots = ( @@ -273,6 +284,13 @@ class GAROM(MultiSolverInterface): return loss def test_step(self, batch): + """ + The test step for the PINN solver. + + :param dict batch: The batch of data to use in the test step. + :return: The loss of the test step. + :rtype: torch.Tensor + """ condition_loss = {} for condition_name, points in batch: parameters, snapshots = ( @@ -289,30 +307,60 @@ class GAROM(MultiSolverInterface): @property def generator(self): - """TODO""" + """ + The generator model. + + :return: The generator model. + :rtype: torch.nn.Module + """ return self.models[0] @property def discriminator(self): - """TODO""" + """ + The discriminator model. + + :return: The discriminator model. + :rtype: torch.nn.Module + """ return self.models[1] @property def optimizer_generator(self): - """TODO""" + """ + The optimizer for the generator. + + :return: The optimizer for the generator. + :rtype: Optimizer + """ return self.optimizers[0].instance @property def optimizer_discriminator(self): - """TODO""" + """ + The optimizer for the discriminator. + + :return: The optimizer for the discriminator. + :rtype: Optimizer + """ return self.optimizers[1].instance @property def scheduler_generator(self): - """TODO""" + """ + The scheduler for the generator. + + :return: The scheduler for the generator. + :rtype: Scheduler + """ return self.schedulers[0].instance @property def scheduler_discriminator(self): - """TODO""" + """ + The scheduler for the discriminator. + + :return: The scheduler for the discriminator. + :rtype: Scheduler + """ return self.schedulers[1].instance diff --git a/pina/solver/physic_informed_solver/pinn_interface.py b/pina/solver/physic_informed_solver/pinn_interface.py index a1f1864..248822e 100644 --- a/pina/solver/physic_informed_solver/pinn_interface.py +++ b/pina/solver/physic_informed_solver/pinn_interface.py @@ -110,8 +110,8 @@ class PINNInterface(SolverInterface, metaclass=ABCMeta): def loss_data(self, input_pts, output_pts): """ Compute the data loss for the PINN solver by evaluating the loss - between the network's output and the true solution. This method - should only be overridden intentionally. + between the network's output and the true solution. This method should + not be overridden, if not intentionally. :param LabelTensor input_pts: The input points to the neural network. :param LabelTensor output_pts: The true solution to compare with the diff --git a/pina/solver/reduced_order_model.py b/pina/solver/reduced_order_model.py index c80556b..1b61f06 100644 --- a/pina/solver/reduced_order_model.py +++ b/pina/solver/reduced_order_model.py @@ -1,19 +1,17 @@ """Module for ReducedOrderModelSolver""" import torch - from .supervised import SupervisedSolver class ReducedOrderModelSolver(SupervisedSolver): r""" - ReducedOrderModelSolver solver class. This class implements a - Reduced Order Model solver, using user specified ``reduction_network`` and + Reduced Order Model solver class. This class implements the Reduced Order + Model solver, using user specified ``reduction_network`` and ``interpolation_network`` to solve a specific ``problem``. - The Reduced Order Model approach aims to find - the solution :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` - of the differential problem: + The Reduced Order Model solver aims to find the solution + :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m` of a differential problem: .. math:: @@ -23,13 +21,13 @@ class ReducedOrderModelSolver(SupervisedSolver): \mathbf{x}\in\partial\Omega \end{cases} - This is done by using two neural networks. The ``reduction_network``, which - contains an encoder :math:`\mathcal{E}_{\rm{net}}`, a decoder - :math:`\mathcal{D}_{\rm{net}}`; and an ``interpolation_network`` + This is done by means of two neural networks: the ``reduction_network``, + which defines an encoder :math:`\mathcal{E}_{\rm{net}}`, and a decoder + :math:`\mathcal{D}_{\rm{net}}`; and the ``interpolation_network`` :math:`\mathcal{I}_{\rm{net}}`. The input is assumed to be discretised in the spatial dimensions. - The following loss function is minimized during training + The following loss function is minimized during training: .. math:: \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N @@ -39,49 +37,46 @@ class ReducedOrderModelSolver(SupervisedSolver): \mathcal{D}_{\rm{net}}[\mathcal{E}_{\rm{net}}[\mathbf{u}(\mu_i)]] - \mathbf{u}(\mu_i)) - where :math:`\mathcal{L}` is a specific loss function, default - Mean Square Error: + where :math:`\mathcal{L}` is a specific loss function, typically the MSE: .. math:: \mathcal{L}(v) = \| v \|^2_2. - .. seealso:: **Original reference**: Hesthaven, Jan S., and Stefano Ubbiali. - "Non-intrusive reduced order modeling of nonlinear problems - using neural networks." Journal of Computational - Physics 363 (2018): 55-78. + "Non-intrusive reduced order modeling of nonlinear problems using + neural networks." + Journal of Computational Physics 363 (2018): 55-78. DOI `10.1016/j.jcp.2018.02.037 `_. .. note:: - The specified ``reduction_network`` must contain two methods, - namely ``encode`` for input encoding and ``decode`` for decoding the - former result. The ``interpolation_network`` network ``forward`` output - represents the interpolation of the latent space obtain with + The specified ``reduction_network`` must contain two methods, namely + ``encode`` for input encoding, and ``decode`` for decoding the former + result. The ``interpolation_network`` network ``forward`` output + represents the interpolation of the latent space obtained with ``reduction_network.encode``. .. note:: This solver uses the end-to-end training strategy, i.e. the ``reduction_network`` and ``interpolation_network`` are trained - simultaneously. For reference on this trainig strategy look at: - Pichi, Federico, Beatriz Moya, and Jan S. Hesthaven. + simultaneously. For reference on this trainig strategy look at the + following: + + ..seealso:: + **Original reference**: Pichi, Federico, Beatriz Moya, and Jan S. + Hesthaven. "A graph convolutional autoencoder approach to model order reduction - for parametrized PDEs." Journal of - Computational Physics 501 (2024): 112762. - DOI - `10.1016/j.jcp.2024.112762 `_. + for parametrized PDEs." + Journal of Computational Physics 501 (2024): 112762. + DOI `10.1016/j.jcp.2024.112762 + `_. .. warning:: This solver works only for data-driven model. Hence in the ``problem`` definition the codition must only contain ``input`` (e.g. coefficient parameters, time parameters), and ``target``. - - .. warning:: - This solver does not currently support the possibility to pass - ``extra_feature``. """ def __init__( @@ -96,22 +91,28 @@ class ReducedOrderModelSolver(SupervisedSolver): use_lt=True, ): """ + Initialization of the :class:`ReducedOrderModelSolver` class. + :param AbstractProblem problem: The formualation of the problem. :param torch.nn.Module reduction_network: The reduction network used - for reducing the input space. It must contain two methods, - namely ``encode`` for input encoding and ``decode`` for decoding the + for reducing the input space. It must contain two methods, namely + ``encode`` for input encoding, and ``decode`` for decoding the former result. :param torch.nn.Module interpolation_network: The interpolation network - for interpolating the control parameters to latent space obtain by + for interpolating the control parameters to latent space obtained by the ``reduction_network`` encoding. - :param torch.nn.Module loss: The loss function used as minimizer, - default :class:`torch.nn.MSELoss`. - :param torch.optim.Optimizer optimizer: The neural network optimizer to - use; default is :class:`torch.optim.Adam`. - :param torch.optim.LRScheduler scheduler: Learning - rate scheduler. - :param WeightingInterface weighting: The loss weighting to use. - :param bool use_lt: Using LabelTensors as input during training. + :param torch.nn.Module loss: The loss function to be minimized. + If `None`, the :class:`torch.nn.MSELoss` loss is used. + Default is `None`. + :param Optimizer optimizer: The optimizer to be used. + If `None`, the :class:`torch.optim.Adam`. optimizer is used. + Default is ``None``. + :param Scheduler scheduler: Learning rate scheduler. If `None`, + the constant learning rate scheduler is used. Default is ``None``. + :param WeightingInterface weighting: The weighting schema to be used. + If `None`, no weighting schema is used. Default is ``None``. + :param bool use_lt: If ``True``, the solver uses LabelTensors as input. + Default is ``True``. """ model = torch.nn.ModuleDict( { @@ -146,10 +147,10 @@ class ReducedOrderModelSolver(SupervisedSolver): def forward(self, x): """ - Forward pass implementation for the solver. It finds the encoder - representation by calling ``interpolation_network.forward`` on the - input, and maps this representation to output space by calling - ``reduction_network.decode``. + Forward pass implementation. + It computes the encoder representation by calling the forward method + of the ``interpolation_network`` on the input, and maps it to output + space by calling the decode methode of the ``reduction_network``. :param torch.Tensor x: Input tensor. :return: Solver solution. @@ -161,15 +162,14 @@ class ReducedOrderModelSolver(SupervisedSolver): def loss_data(self, input_pts, output_pts): """ - The data loss for the ReducedOrderModelSolver solver. - It computes the loss between - the network output against the true solution. This function - should not be override if not intentionally. + Compute the data loss by evaluating the loss between the network's + output and the true solution. This method should not be overridden, if + not intentionally. - :param LabelTensor input_tensor: The input to the neural networks. - :param LabelTensor output_tensor: The true solution to compare the - network solution. - :return: The residual loss averaged on the input coordinates + :param LabelTensor input_pts: The input points to the neural network. + :param LabelTensor output_pts: The true solution to compare with the + network's output. + :return: The supervised loss, averaged over the number of observations. :rtype: torch.Tensor """ # extract networks diff --git a/pina/solver/solver.py b/pina/solver/solver.py index f671a7e..cc49c34 100644 --- a/pina/solver/solver.py +++ b/pina/solver/solver.py @@ -14,17 +14,19 @@ from ..utils import check_consistency, labelize_forward class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): """ - SolverInterface base class. This class is a wrapper of LightningModule. + Abstract base class for PINA solvers. All specific solvers should inherit + from this interface. This class is a wrapper of + :class:`~lightning.pytorch.LightningModule`. """ def __init__(self, problem, weighting, use_lt): """ - :param problem: A problem definition instance. - :type problem: AbstractProblem - :param weighting: The loss weighting to use. - :type weighting: WeightingInterface - :param use_lt: Using LabelTensors as input during training. - :type use_lt: bool + Initialization of the :class:`SolverInterface` class. + + :param AbstractProblem problem: The problem to be solved. + :param WeightingInterface weighting: The weighting schema to be used. + If `None`, no weighting schema is used. Default is ``None``. + :param bool use_lt: If ``True``, the solver uses LabelTensors as input. """ super().__init__() @@ -59,22 +61,24 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): self._pina_schedulers = None def _check_solver_consistency(self, problem): + """ + Check the consistency of the solver with the problem formulation. + + :param AbstractProblem problem: The problem to be solved. + """ for condition in problem.conditions.values(): check_consistency(condition, self.accepted_conditions_types) def _optimization_cycle(self, batch): """ - Perform a private optimization cycle by computing the loss for each - condition in the given batch. The loss are later aggregated using the - specific weighting schema. + Aggregate the loss for each condition in the batch. - :param batch: A batch of data, where each element is a tuple containing - a condition name and a dictionary of points. - :type batch: list of tuples (str, dict) - :return: The computed loss for the all conditions in the batch, - cast to a subclass of `torch.Tensor`. It should return a dict - containing the condition name and the associated scalar loss. - :rtype: dict(torch.Tensor) + :param list[tuple[str, dict]] batch: A batch of data. Each element is a + tuple containing a condition name and a dictionary of points. + :return: The computed loss for the all conditions in the batch, casted + to a subclass of `torch.Tensor`. It should return a dict containing + the condition name and the associated scalar loss. + :rtype: dict """ losses = self.optimization_cycle(batch) for name, value in losses.items(): @@ -88,9 +92,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): """ Solver training step. - :param batch: The batch element in the dataloader. - :type batch: tuple - :return: The sum of the loss functions. + :param list[tuple[str, dict]] batch: The batch element in the dataloader. + :return: The loss of the training step. :rtype: LabelTensor """ loss = self._optimization_cycle(batch=batch) @@ -101,8 +104,7 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): """ Solver validation step. - :param batch: The batch element in the dataloader. - :type batch: tuple + :param list[tuple[str, dict]] batch: The batch element in the dataloader. """ loss = self._optimization_cycle(batch=batch) self.store_log("val_loss", loss, self.get_batch_size(batch)) @@ -111,15 +113,18 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): """ Solver test step. - :param batch: The batch element in the dataloader. - :type batch: tuple + :param list[tuple[str, dict]] batch: The batch element in the dataloader. """ loss = self._optimization_cycle(batch=batch) self.store_log("test_loss", loss, self.get_batch_size(batch)) def store_log(self, name, value, batch_size): """ - TODO + Store the log of the solver. + + :param str name: The name of the log. + :param torch.Tensor value: The value of the log. + :param int batch_size: The size of the batch. """ self.log( @@ -132,49 +137,59 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): @abstractmethod def forward(self, *args, **kwargs): """ - TODO + Abstract method for the forward pass implementation. """ @abstractmethod def optimization_cycle(self, batch): """ - Perform an optimization cycle by computing the loss for each condition - in the given batch. + The optimization cycle for the solvers. - :param batch: A batch of data, where each element is a tuple containing - a condition name and a dictionary of points. - :type batch: list of tuples (str, dict) - :return: The computed loss for the all conditions in the batch, - cast to a subclass of `torch.Tensor`. It should return a dict - containing the condition name and the associated scalar loss. - :rtype: dict(torch.Tensor) + :param list[tuple[str, dict]] batch: The batch element in the dataloader. + :return: The computed loss for the all conditions in the batch, casted + to a subclass of `torch.Tensor`. It should return a dict containing + the condition name and the associated scalar loss. + :rtype: dict """ @property def problem(self): """ - The problem formulation. + The problem instance. + + :return: The problem instance. + :rtype: :class:`~pina.problem.abstract_problem.AbstractProblem` """ return self._pina_problem @property def use_lt(self): """ - Using LabelTensor in training. + Using LabelTensors as input during training. + + :return: The use_lt attribute. + :rtype: bool """ return self._use_lt @property def weighting(self): """ - The weighting mechanism. + The weighting schema. + + :return: The weighting schema. + :rtype: :class:`~pina.loss.weighting_interface.WeightingInterface` """ return self._pina_weighting @staticmethod def get_batch_size(batch): """ - TODO + Get the batch size. + + :param list[tuple[str, dict]] batch: The batch element in the dataloader. + :return: The size of the batch. + :rtype: int """ batch_size = 0 @@ -185,23 +200,29 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): @staticmethod def default_torch_optimizer(): """ - TODO - """ + Set the default optimizer to :class:`torch.optim.Adam`. + :return: The default optimizer. + :rtype: Optimizer + """ return TorchOptimizer(torch.optim.Adam, lr=0.001) @staticmethod def default_torch_scheduler(): """ - TODO + Set the default scheduler to + :class:`torch.optim.lr_scheduler.ConstantLR`. + + :return: The default scheduler. + :rtype: Scheduler """ return TorchScheduler(torch.optim.lr_scheduler.ConstantLR) def on_train_start(self): """ - Hook that is called before training begins. - Used to compile the model if the trainer is set to compile. + This method is called at the start of the training process to compile + the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``. """ super().on_train_start() if self.trainer.compile: @@ -209,8 +230,8 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): def on_test_start(self): """ - Hook that is called before training begins. - Used to compile the model if the trainer is set to compile. + This method is called at the start of the test process to compile + the model if the :class:`~pina.trainer.Trainer` ``compile`` is ``True``. """ super().on_train_start() if self.trainer.compile and not self._check_already_compiled(): @@ -218,7 +239,10 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): def _check_already_compiled(self): """ - TODO + Check if the model is already compiled. + + :return: ``True`` if the model is already compiled, ``False`` otherwise. + :rtype: bool """ models = self._pina_models @@ -234,7 +258,12 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): @staticmethod def _perform_compilation(model): """ - TODO + Perform the compilation of the model. + + :param torch.nn.Module model: The model to compile. + :raises Exception: If the compilation fails. + :return: The compiled model. + :rtype: torch.nn.Module """ model_device = next(model.parameters()).device @@ -249,8 +278,9 @@ class SolverInterface(lightning.pytorch.LightningModule, metaclass=ABCMeta): class SingleSolverInterface(SolverInterface, metaclass=ABCMeta): - """TODO""" - + """ + Base class for PINA solvers using a single :class:`torch.nn.Module`. + """ def __init__( self, problem, @@ -261,14 +291,18 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta): use_lt=True, ): """ - :param problem: A problem definition instance. - :type problem: AbstractProblem - :param model: A torch nn.Module instances. - :type model: torch.nn.Module - :param Optimizer optimizers: A neural network optimizers to use. - :param Scheduler optimizers: A neural network scheduler to use. - :param WeightingInterface weighting: The loss weighting to use. - :param bool use_lt: Using LabelTensors as input during training. + Initialization of the :class:`SingleSolverInterface` class. + + :param AbstractProblem problem: The problem to be solved. + :param torch.nn.Module model: The neural network model to be used. + :param Optimizer optimizer: The optimizer to be used. + If `None`, the Adam optimizer is used. Default is ``None``. + :param Scheduler scheduler: The scheduler to be used. + If `None`, the constant learning rate scheduler is used. + Default is ``None``. + :param WeightingInterface weighting: The weighting schema to be used. + If `None`, no weighting schema is used. Default is ``None``. + :param bool use_lt: If ``True``, the solver uses LabelTensors as input. """ if optimizer is None: optimizer = self.default_torch_optimizer() @@ -292,11 +326,12 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta): def forward(self, x): """ - Forward pass implementation for the solver. + Forward pass implementation. - :param torch.Tensor x: Input tensor. + :param x: Input tensor. + :type x: torch.Tensor | LabelTensor :return: Solver solution. - :rtype: torch.Tensor + :rtype: torch.Tensor | LabelTensor """ x = self.model(x) return x @@ -305,7 +340,7 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta): """ Optimizer configuration for the solver. - :return: The optimizers and the schedulers + :return: The optimizer and the scheduler :rtype: tuple(list, list) """ self.optimizer.hook(self.model.parameters()) @@ -313,44 +348,61 @@ class SingleSolverInterface(SolverInterface, metaclass=ABCMeta): return ([self.optimizer.instance], [self.scheduler.instance]) def _compile_model(self): + """ + Compile the model. + """ if isinstance(self._pina_models[0], torch.nn.ModuleDict): self._compile_module_dict() else: self._compile_single_model() def _compile_module_dict(self): + """ + Compile the model if it is a :class:`torch.nn.ModuleDict`. + """ for name, model in self._pina_models[0].items(): self._pina_models[0][name] = self._perform_compilation(model) def _compile_single_model(self): + """ + Compile the model if it is a single :class:`torch.nn.Module`. + """ self._pina_models[0] = self._perform_compilation(self._pina_models[0]) @property def model(self): """ - Model for training. + The model used for training. + + :return: The model used for training. + :rtype: torch.nn.Module """ return self._pina_models[0] @property def scheduler(self): """ - Scheduler for training. + The scheduler used for training. + + :return: The scheduler used for training. + :rtype: Scheduler """ return self._pina_schedulers[0] @property def optimizer(self): """ - Optimizer for training. + The optimizer used for training. + + :return: The optimizer used for training. + :rtype: Optimizer """ return self._pina_optimizers[0] class MultiSolverInterface(SolverInterface, metaclass=ABCMeta): """ - Multiple Solver base class. This class inherits is a wrapper of - SolverInterface class + Base class for PINA solvers using multiple :class:`torch.nn.Module`. """ def __init__( @@ -363,16 +415,22 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta): use_lt=True, ): """ - :param problem: A problem definition instance. - :type problem: AbstractProblem - :param models: Multiple torch nn.Module instances. + Initialization of the :class:`MultiSolverInterface` class. + + :param AbstractProblem problem: The problem to be solved. + :param models: The neural network models to be used. :type model: list[torch.nn.Module] | tuple[torch.nn.Module] - :param list(Optimizer) optimizers: A list of neural network - optimizers to use. - :param list(Scheduler) optimizers: A list of neural network - schedulers to use. - :param WeightingInterface weighting: The loss weighting to use. - :param bool use_lt: Using LabelTensors as input during training. + :param list[Optimizer] optimizers: The optimizers to be used. + If `None`, the Adam optimizer is used for all models. + Default is ``None``. + :param list[Scheduler] schedulers: The schedulers to be used. + If `None`, the constant learning rate scheduler is used for all the + models. Default is ``None``. + :param WeightingInterface weighting: The weighting schema to be used. + If `None`, no weighting schema is used. Default is ``None``. + :param bool use_lt: If ``True``, the solver uses LabelTensors as input. + :raises ValueError: If the models are not a list or tuple with length + greater than one. """ if not isinstance(models, (list, tuple)) or len(models) < 2: raise ValueError( @@ -418,9 +476,10 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta): self._pina_schedulers = schedulers def configure_optimizers(self): - """Optimizer configuration for the solver. + """ + Optimizer configuration for the solver. - :return: The optimizers and the schedulers + :return: The optimizer and the scheduler :rtype: tuple(list, list) """ for optimizer, scheduler, model in zip( @@ -435,6 +494,9 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta): ) def _compile_model(self): + """ + Compile the model. + """ for i, model in enumerate(self._pina_models): if not isinstance(model, torch.nn.ModuleDict): self._pina_models[i] = self._perform_compilation(model) @@ -442,17 +504,29 @@ class MultiSolverInterface(SolverInterface, metaclass=ABCMeta): @property def models(self): """ - The torch model.""" + The models used for training. + + :return: The models used for training. + :rtype: torch.nn.ModuleList + """ return self._pina_models @property def optimizers(self): """ - The torch model.""" + The optimizers used for training. + + :return: The optimizers used for training. + :rtype: list[Optimizer] + """ return self._pina_optimizers @property def schedulers(self): """ - The torch model.""" + The schedulers used for training. + + :return: The schedulers used for training. + :rtype: list[Scheduler] + """ return self._pina_schedulers diff --git a/pina/solver/supervised.py b/pina/solver/supervised.py index 2bfa858..fd93553 100644 --- a/pina/solver/supervised.py +++ b/pina/solver/supervised.py @@ -1,4 +1,4 @@ -"""Module for SupervisedSolver""" +"""Module for the Supervised Solver.""" import torch from torch.nn.modules.loss import _Loss @@ -10,31 +10,28 @@ from ..condition import InputTargetCondition class SupervisedSolver(SingleSolverInterface): r""" - SupervisedSolver solver class. This class implements a SupervisedSolver, + Supervised Solver solver class. This class implements a Supervised Solver, using a user specified ``model`` to solve a specific ``problem``. - The Supervised Solver class aims to find - a map between the input :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m` - and the output :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. The input - can be discretised in space (as in :obj:`~pina.solver.rom.ROMe2eSolver`), - or not (e.g. when training Neural Operators). + The Supervised Solver class aims to find a map between the input + :math:`\mathbf{s}:\Omega\rightarrow\mathbb{R}^m` and the output + :math:`\mathbf{u}:\Omega\rightarrow\mathbb{R}^m`. Given a model :math:`\mathcal{M}`, the following loss function is minimized during training: .. math:: \mathcal{L}_{\rm{problem}} = \frac{1}{N}\sum_{i=1}^N - \mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i)) + \mathcal{L}(\mathbf{u}_i - \mathcal{M}(\mathbf{v}_i)), - where :math:`\mathcal{L}` is a specific loss function, - default Mean Square Error: + where :math:`\mathcal{L}` is a specific loss function, typically the MSE: .. math:: \mathcal{L}(v) = \| v \|^2_2. - In this context :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` means that - we are seeking to approximate multiple (discretised) functions given - multiple (discretised) input functions. + In this context, :math:`\mathbf{u}_i` and :math:`\mathbf{v}_i` indicates + the will to approximate multiple (discretised) functions given multiple + (discretised) input functions. """ accepted_conditions_types = InputTargetCondition @@ -50,16 +47,22 @@ class SupervisedSolver(SingleSolverInterface): use_lt=True, ): """ - :param AbstractProblem problem: The formualation of the problem. - :param torch.nn.Module model: The neural network model to use. - :param torch.nn.Module loss: The loss function used as minimizer, - default :class:`torch.nn.MSELoss`. - :param torch.optim.Optimizer optimizer: The neural network optimizer to - use; default is :class:`torch.optim.Adam`. - :param torch.optim.LRScheduler scheduler: Learning - rate scheduler. - :param WeightingInterface weighting: The loss weighting to use. - :param bool use_lt: Using LabelTensors as input during training. + Initialization of the :class:`SupervisedSolver` class. + + :param AbstractProblem problem: The problem to be solved. + :param torch.nn.Module model: The neural network model to be used. + :param torch.nn.Module loss: The loss function to be minimized. + If `None`, the Mean Squared Error (MSE) loss is used. + Default is `None`. + :param torch.optim.Optimizer optimizer: The optimizer to be used. + If `None`, the Adam optimizer is used. Default is ``None``. + :param torch.optim.LRScheduler scheduler: Learning rate scheduler. + If `None`, the constant learning rate scheduler is used. + Default is ``None``. + :param WeightingInterface weighting: The weighting schema to be used. + If `None`, no weighting schema is used. Default is ``None``. + :param bool use_lt: If ``True``, the solver uses LabelTensors as input. + Default is ``True``. """ if loss is None: loss = torch.nn.MSELoss() @@ -81,16 +84,13 @@ class SupervisedSolver(SingleSolverInterface): def optimization_cycle(self, batch): """ - Perform an optimization cycle by computing the loss for each condition - in the given batch. + The optimization cycle for the solvers. - :param batch: A batch of data, where each element is a tuple containing - a condition name and a dictionary of points. - :type batch: list of tuples (str, dict) - :return: The computed loss for the all conditions in the batch, - cast to a subclass of `torch.Tensor`. It should return a dict - containing the condition name and the associated scalar loss. - :rtype: dict(torch.Tensor) + :param list[tuple[str, dict]] batch: The batch element in the dataloader. + :return: The computed loss for the all conditions in the batch, casted + to a subclass of `torch.Tensor`. It should return a dict containing + the condition name and the associated scalar loss. + :rtype: dict """ condition_loss = {} for condition_name, points in batch: @@ -105,16 +105,16 @@ class SupervisedSolver(SingleSolverInterface): def loss_data(self, input_pts, output_pts): """ - The data loss for the Supervised solver. It computes the loss between - the network output against the true solution. This function - should not be override if not intentionally. + Compute the data loss for the Supervised solver by evaluating the loss + between the network's output and the true solution. This method should + not be overridden, if not intentionally. - :param input_pts: The input to the neural networks. + :param input_pts: The input points to the neural network. :type input_pts: LabelTensor | torch.Tensor - :param output_pts: The true solution to compare the - network solution. + :param output_pts: The true solution to compare with the network's + output. :type output_pts: LabelTensor | torch.Tensor - :return: The residual loss. + :return: The supervised loss, averaged over the number of observations. :rtype: torch.Tensor """ return self._loss(self.forward(input_pts), output_pts) @@ -122,6 +122,9 @@ class SupervisedSolver(SingleSolverInterface): @property def loss(self): """ - Loss for training. + The loss function to be minimized. + + :return: The loss function to be minimized. + :rtype: torch.nn.Module """ return self._loss