diff --git a/pina/trainer.py b/pina/trainer.py index 3fb73bf..37c1c06 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -1,4 +1,4 @@ -"""Trainer module.""" +"""Module for the Trainer.""" import sys import torch @@ -10,8 +10,11 @@ from .solver import SolverInterface, PINNInterface class Trainer(lightning.pytorch.Trainer): """ - PINA custom Trainer class which allows to customize standard Lightning - Trainer class for PINNs training. + PINA custom Trainer class to extend the standard Lightning functionality. + + This class enables specific features or behaviors required by the PINA + framework. It modifies the standard :class:`lightning.pytorch.Trainer` class + to better support the training process in PINA. """ def __init__( @@ -29,42 +32,35 @@ class Trainer(lightning.pytorch.Trainer): **kwargs, ): """ - Initialize the Trainer class for by calling Lightning costructor and - adding many other functionalities. + Initialization of the :class:`Trainer` class. - :param solver: A pina:class:`SolverInterface` solver for the - differential problem. - :type solver: SolverInterface - :param batch_size: How many samples per batch to load. - If ``batch_size=None`` all - samples are loaded and data are not batched, defaults to None. - :type batch_size: int | None - :param train_size: Percentage of elements in the train dataset. - :type train_size: float - :param test_size: Percentage of elements in the test dataset. - :type test_size: float - :param val_size: Percentage of elements in the val dataset. - :type val_size: float - :param compile: if True model is compiled before training, - default False. For Windows users compilation is always disabled. - :type compile: bool - :param automatic_batching: if True automatic PyTorch batching is - performed. Please avoid using automatic batching when batch_size is - large, default False. - :type automatic_batching: bool - :param num_workers: Number of worker threads for data loading. - Default 0 (serial loading). - :type num_workers: int - :param pin_memory: Whether to use pinned memory for faster data - transfer to GPU. Default False. - :type pin_memory: bool - :param shuffle: Whether to shuffle the data for training. Default True. - :type pin_memory: bool + :param SolverInterface solver: A :class:`~pina.solver.SolverInterface` + solver used to solve a :class:`~pina.problem.AbstractProblem`. + :param int batch_size: The number of samples per batch to load. + If ``None``, all samples are loaded and data is not batched. + Default is ``None``. + :param float train_size: The percentage of elements to include in the + training dataset. Default is ``1.0``. + :param float test_size: The percentage of elements to include in the + test dataset. Default is ``0.0``. + :param float val_size: The percentage of elements to include in the + validation dataset. Default is ``0.0``. + :param bool compile: If ``True``, the model is compiled before training. + Default is ``False``. For Windows users, it is always disabled. + :param bool automatic_batching: If ``True``, automatic PyTorch batching + is performed. Avoid using automatic batching when ``batch_size`` is + large. Default is ``False``. + :param int num_workers: The number of worker threads for data loading. + Default is ``0`` (serial loading). + :param bool pin_memory: Whether to use pinned memory for faster data + transfer to GPU. Default is ``False``. + :param bool shuffle: Whether to shuffle the data during training. + Default is ``True``. :Keyword Arguments: - The additional keyword arguments specify the training setup - and can be choosen from the `pytorch-lightning - Trainer API `_ + Additional keyword arguments that specify the training setup. + These can be selected from the pytorch-lightning Trainer API + _. """ # check consistency for init types self._check_input_consistency( @@ -134,6 +130,10 @@ class Trainer(lightning.pytorch.Trainer): } def _move_to_device(self): + """ + Moves the ``unknown_parameters`` of an instance of + :class:`~pina.problem.AbstractProblem` to the :class:`Trainer` device. + """ device = self._accelerator_connector._parallel_devices[0] # move parameters to device pb = self.solver.problem @@ -155,9 +155,25 @@ class Trainer(lightning.pytorch.Trainer): shuffle, ): """ - This method is used here because is resampling is needed - during training, there is no need to define to touch the - trainer dataloader, just call the method. + This method is designed to handle the creation of a data module when + resampling is needed during training. Instead of manually defining and + modifying the trainer's dataloaders, this method is called to + automatically configure the data module. + + :param float train_size: The percentage of elements to include in the + training dataset. + :param float test_size: The percentage of elements to include in the + test dataset. + :param float val_size: The percentage of elements to include in the + validation dataset. + :param int batch_size: The number of samples per batch to load. + :param bool automatic_batching: Whether to perform automatic batching + with PyTorch. + :param bool pin_memory: Whether to use pinned memory for faster data + transfer to GPU. + :param int num_workers: The number of worker threads for data loading. + :param bool shuffle: Whether to shuffle the data during training. + :raises RuntimeError: If not all conditions are sampled. """ if not self.solver.problem.are_all_domains_discretised: error_message = "\n".join( @@ -188,25 +204,33 @@ class Trainer(lightning.pytorch.Trainer): def train(self, **kwargs): """ - Train the solver method. + Manage the training process of the solver. """ return super().fit(self.solver, datamodule=self.data_module, **kwargs) def test(self, **kwargs): """ - Test the solver method. + Manage the test process of the solver. """ return super().test(self.solver, datamodule=self.data_module, **kwargs) @property def solver(self): """ - Returning trainer solver. + Get the solver. + + :return: The solver. + :rtype: SolverInterface """ return self._solver @solver.setter def solver(self, solver): + """ + Set the solver. + + :param SolverInterface solver: The solver to set. + """ self._solver = solver @staticmethod @@ -214,7 +238,18 @@ class Trainer(lightning.pytorch.Trainer): solver, train_size, test_size, val_size, automatic_batching, compile ): """ - Check the consistency of the input parameters." + Verifies the consistency of the parameters for the solver configuration. + + :param SolverInterface solver: The solver. + :param float train_size: The percentage of elements to include in the + training dataset. + :param float test_size: The percentage of elements to include in the + test dataset. + :param float val_size: The percentage of elements to include in the + validation dataset. + :param bool automatic_batching: Whether to perform automatic batching + with PyTorch. + :param bool compile: If ``True``, the model is compiled before training. """ check_consistency(solver, SolverInterface) @@ -231,8 +266,14 @@ class Trainer(lightning.pytorch.Trainer): pin_memory, num_workers, shuffle, batch_size ): """ - Check the consistency of the input parameters and set the default - values. + Checks the consistency of input parameters and sets default values + for missing or invalid parameters. + + :param bool pin_memory: Whether to use pinned memory for faster data + transfer to GPU. + :param int num_workers: The number of worker threads for data loading. + :param bool shuffle: Whether to shuffle the data during training. + :param int batch_size: The number of samples per batch to load. """ if pin_memory is not None: check_consistency(pin_memory, bool) diff --git a/pina/utils.py b/pina/utils.py index 529c98b..123e3e4 100644 --- a/pina/utils.py +++ b/pina/utils.py @@ -1,4 +1,4 @@ -"""Utils module.""" +"""Module for utility functions.""" import types from functools import reduce @@ -12,14 +12,15 @@ def custom_warning_format( message, category, filename, lineno, file=None, line=None ): """ - Depewarning custom format. + Custom warning formatting function. :param str message: The warning message. - :param class category: The warning category. - :param str filename: The filename where the warning was raised. - :param int lineno: The line number where the warning was raised. - :param str file: The file object where the warning was raised. - :param inr line: The line where the warning was raised. + :param Warning category: The warning category. + :param str filename: The filename where the warning is raised. + :param int lineno: The line number where the warning is raised. + :param str file: The file object where the warning is raised. + Default is None. + :param int line: The line where the warning is raised. :return: The formatted warning message. :rtype: str """ @@ -27,20 +28,20 @@ def custom_warning_format( def check_consistency(object_, object_instance, subclass=False): - """Helper function to check object inheritance consistency. - Given a specific ``'object'`` we check if the object is - instance of a specific ``'object_instance'``, or in case - ``'subclass=True'`` we check if the object is subclass - if the ``'object_instance'``. + """ + Check if an object maintains inheritance consistency. - :param (iterable or class object) object: The object to check the - inheritance - :param Object object_instance: The parent class from where the object - is expected to inherit - :param str object_name: The name of the object - :param bool subclass: Check if is a subclass and not instance - :raises ValueError: If the object does not inherit from the - specified class + This function checks whether a given object is an instance of a specified + class or, if ``subclass=True``, whether it is a subclass of the specified + class. + + :param object: The object to check. + :type object: Iterable | Object + :param Object object_instance: The expected parent class. + :param bool subclass: If True, checks whether ``object_`` is a subclass + of ``object_instance`` instead of an instance. Default is ``False``. + :raises ValueError: If ``object_`` does not inherit from ``object_instance`` + as expected. """ if not isinstance(object_, (list, set, tuple)): object_ = [object_] @@ -59,18 +60,28 @@ def check_consistency(object_, object_instance, subclass=False): def labelize_forward(forward, input_variables, output_variables): """ - Wrapper decorator to allow users to enable or disable the use of - LabelTensors during the forward pass. + Decorator to enable or disable the use of :class:`~pina.LabelTensor` + during the forward pass. - :param forward: The torch.nn.Module forward function. - :type forward: Callable - :param input_variables: The problem input variables. - :type input_variables: list[str] | tuple[str] - :param output_variables: The problem output variables. - :type output_variables: list[str] | tuple[str] + :param Callable forward: The forward function of a :class:`torch.nn.Module`. + :param list[str] input_variables: The names of the input variables of a + :class:`~pina.problem.AbstractProblem`. + :param list[str] output_variables: The names of the output variables of a + :class:`~pina.problem.AbstractProblem`. + :return: The decorated forward function. + :rtype: Callable """ def wrapper(x): + """ + Decorated forward function. + + :param LabelTensor x: The labelized input of the forward pass of an + instance of :class:`torch.nn.Module`. + :return: The labelized output of the forward pass of an instance of + :class:`torch.nn.Module`. + :rtype: LabelTensor + """ x = x.extract(input_variables) output = forward(x) # keep it like this, directly using LabelTensor(...) raises errors @@ -82,15 +93,32 @@ def labelize_forward(forward, input_variables, output_variables): return wrapper -def merge_tensors(tensors): # name to be changed - """TODO""" +def merge_tensors(tensors): + """ + Merge a list of :class:`~pina.LabelTensor` instances into a single + :class:`~pina.LabelTensor` tensor, by applying iteratively the cartesian + product. + + :param list[LabelTensor] tensors: The list of tensors to merge. + :raises ValueError: If the list of tensors is empty. + :return: The merged tensor. + :rtype: LabelTensor + """ if tensors: return reduce(merge_two_tensors, tensors[1:], tensors[0]) raise ValueError("Expected at least one tensor") def merge_two_tensors(tensor1, tensor2): - """TODO""" + """ + Merge two :class:`~pina.LabelTensor` instances into a single + :class:`~pina.LabelTensor` tensor, by applying the cartesian product. + + :param LabelTensor tensor1: The first tensor to merge. + :param LabelTensor tensor2: The second tensor to merge. + :return: The merged tensor. + :rtype: LabelTensor + """ n1 = tensor1.shape[0] n2 = tensor2.shape[0] @@ -102,12 +130,14 @@ def merge_two_tensors(tensor1, tensor2): def torch_lhs(n, dim): - """Latin Hypercube Sampling torch routine. - Sampling in range $[0, 1)^d$. + """ + The Latin Hypercube Sampling torch routine, sampling in :math:`[0, 1)`$. - :param int n: number of samples - :param int dim: dimensions of latin hypercube - :return: samples + :param int n: The number of points to sample. + :param int dim: The number of dimensions of the sampling space. + :raises TypeError: If `n` or `dim` are not integers. + :raises ValueError: If `dim` is less than 1. + :return: The sampled points. :rtype: torch.tensor """ @@ -137,10 +167,10 @@ def torch_lhs(n, dim): def is_function(f): """ - Checks whether the given object `f` is a function or lambda. + Check if the given object is a function or a lambda. - :param object f: The object to be checked. - :return: `True` if `f` is a function, `False` otherwise. + :param Object f: The object to be checked. + :return: ``True`` if ``f`` is a function, ``False`` otherwise. :rtype: bool """ return isinstance(f, (types.FunctionType, types.LambdaType)) @@ -148,11 +178,11 @@ def is_function(f): def chebyshev_roots(n): """ - Return the roots of *n* Chebyshev polynomials (between [-1, 1]). + Compute the roots of the Chebyshev polynomial of degree ``n``. - :param int n: number of roots - :return: roots - :rtype: torch.tensor + :param int n: The number of roots to return. + :return: The roots of the Chebyshev polynomials. + :rtype: torch.Tensor """ pi = torch.acos(torch.zeros(1)).item() * 2 k = torch.arange(n)