fix utils and trainer doc

giovanni
2025-03-13 10:47:30 +01:00
committed by FilippoOlivo
parent b23c0f186a
commit 6883763949
2 changed files with 159 additions and 88 deletions


@@ -1,4 +1,4 @@
"""Trainer module."""
"""Module for the Trainer."""
import sys
import torch
@@ -10,8 +10,11 @@ from .solver import SolverInterface, PINNInterface
class Trainer(lightning.pytorch.Trainer):
"""
PINA custom Trainer class which allows to customize standard Lightning
Trainer class for PINNs training.
PINA custom Trainer class to extend the standard Lightning functionality.
This class enables specific features or behaviors required by the PINA
framework. It modifies the standard :class:`lightning.pytorch.Trainer` class
to better support the training process in PINA.
"""
def __init__(
@@ -29,42 +32,35 @@ class Trainer(lightning.pytorch.Trainer):
**kwargs,
):
"""
Initialize the Trainer class for by calling Lightning costructor and
adding many other functionalities.
Initialization of the :class:`Trainer` class.
:param solver: A pina:class:`SolverInterface` solver for the
differential problem.
:type solver: SolverInterface
:param batch_size: How many samples per batch to load.
If ``batch_size=None`` all
samples are loaded and data are not batched, defaults to None.
:type batch_size: int | None
:param train_size: Percentage of elements in the train dataset.
:type train_size: float
:param test_size: Percentage of elements in the test dataset.
:type test_size: float
:param val_size: Percentage of elements in the val dataset.
:type val_size: float
:param compile: if True model is compiled before training,
default False. For Windows users compilation is always disabled.
:type compile: bool
:param automatic_batching: if True automatic PyTorch batching is
performed. Please avoid using automatic batching when batch_size is
large, default False.
:type automatic_batching: bool
:param num_workers: Number of worker threads for data loading.
Default 0 (serial loading).
:type num_workers: int
:param pin_memory: Whether to use pinned memory for faster data
transfer to GPU. Default False.
:type pin_memory: bool
:param shuffle: Whether to shuffle the data for training. Default True.
:type pin_memory: bool
:param SolverInterface solver: A :class:`~pina.solver.SolverInterface`
solver used to solve a :class:`~pina.problem.AbstractProblem`.
:param int batch_size: The number of samples per batch to load.
If ``None``, all samples are loaded and data is not batched.
Default is ``None``.
:param float train_size: The percentage of elements to include in the
training dataset. Default is ``1.0``.
:param float test_size: The percentage of elements to include in the
test dataset. Default is ``0.0``.
:param float val_size: The percentage of elements to include in the
validation dataset. Default is ``0.0``.
:param bool compile: If ``True``, the model is compiled before training.
Default is ``False``. For Windows users, it is always disabled.
:param bool automatic_batching: If ``True``, automatic PyTorch batching
is performed. Avoid using automatic batching when ``batch_size`` is
large. Default is ``False``.
:param int num_workers: The number of worker threads for data loading.
Default is ``0`` (serial loading).
:param bool pin_memory: Whether to use pinned memory for faster data
transfer to GPU. Default is ``False``.
:param bool shuffle: Whether to shuffle the data during training.
Default is ``True``.
:Keyword Arguments:
The additional keyword arguments specify the training setup
and can be choosen from the `pytorch-lightning
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
Additional keyword arguments that specify the training setup.
These can be selected from the `pytorch-lightning Trainer API
<https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_.
"""
# check consistency for init types
self._check_input_consistency(
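
For orientation, a minimal usage sketch of the constructor documented above (``solver`` is assumed to be an existing :class:`SolverInterface` instance built on an already-discretised problem; the values are illustrative, not prescribed by this commit):

    from pina import Trainer

    trainer = Trainer(
        solver=solver,        # assumed: any SolverInterface instance
        batch_size=32,        # None would load all samples in a single batch
        train_size=0.8,
        val_size=0.1,
        test_size=0.1,
        max_epochs=100,       # extra kwargs are forwarded to lightning.pytorch.Trainer
    )
    trainer.train()           # wraps Lightning's fit()
    trainer.test()            # wraps Lightning's test()
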
@@ -134,6 +130,10 @@ class Trainer(lightning.pytorch.Trainer):
}
def _move_to_device(self):
"""
Moves the ``unknown_parameters`` of an instance of
:class:`~pina.problem.AbstractProblem` to the :class:`Trainer` device.
"""
device = self._accelerator_connector._parallel_devices[0]
# move parameters to device
pb = self.solver.problem
@@ -155,9 +155,25 @@ class Trainer(lightning.pytorch.Trainer):
shuffle,
):
"""
This method is used here because is resampling is needed
during training, there is no need to define to touch the
trainer dataloader, just call the method.
This method is designed to handle the creation of a data module when
resampling is needed during training. Instead of manually defining and
modifying the trainer's dataloaders, this method is called to
automatically configure the data module.
:param float train_size: The percentage of elements to include in the
training dataset.
:param float test_size: The percentage of elements to include in the
test dataset.
:param float val_size: The percentage of elements to include in the
validation dataset.
:param int batch_size: The number of samples per batch to load.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool pin_memory: Whether to use pinned memory for faster data
transfer to GPU.
:param int num_workers: The number of worker threads for data loading.
:param bool shuffle: Whether to shuffle the data during training.
:raises RuntimeError: If not all conditions are sampled.
"""
if not self.solver.problem.are_all_domains_discretised:
error_message = "\n".join(
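
A hedged sketch of the precondition this method enforces: every domain of the problem must be discretised before the data module can be built, otherwise the ``RuntimeError`` documented above is raised. The ``discretise_domain`` call is assumed here and its exact signature may differ across PINA versions:

    # assumed: `problem` is the problem wrapped by `solver`
    problem.discretise_domain(n=100, mode="random", domains="all")
    trainer = Trainer(solver=solver, train_size=0.8, val_size=0.2, batch_size=None)
    trainer.train()
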
@@ -188,25 +204,33 @@ class Trainer(lightning.pytorch.Trainer):
def train(self, **kwargs):
"""
Train the solver method.
Manage the training process of the solver.
"""
return super().fit(self.solver, datamodule=self.data_module, **kwargs)
def test(self, **kwargs):
"""
Test the solver method.
Manage the test process of the solver.
"""
return super().test(self.solver, datamodule=self.data_module, **kwargs)
@property
def solver(self):
"""
Returning trainer solver.
Get the solver.
:return: The solver.
:rtype: SolverInterface
"""
return self._solver
@solver.setter
def solver(self, solver):
"""
Set the solver.
:param SolverInterface solver: The solver to set.
"""
self._solver = solver
@staticmethod
@@ -214,7 +238,18 @@ class Trainer(lightning.pytorch.Trainer):
solver, train_size, test_size, val_size, automatic_batching, compile
):
"""
Check the consistency of the input parameters."
Verifies the consistency of the parameters for the solver configuration.
:param SolverInterface solver: The solver.
:param float train_size: The percentage of elements to include in the
training dataset.
:param float test_size: The percentage of elements to include in the
test dataset.
:param float val_size: The percentage of elements to include in the
validation dataset.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool compile: If ``True``, the model is compiled before training.
"""
check_consistency(solver, SolverInterface)
@@ -231,8 +266,14 @@ class Trainer(lightning.pytorch.Trainer):
pin_memory, num_workers, shuffle, batch_size
):
"""
Check the consistency of the input parameters and set the default
values.
Checks the consistency of input parameters and sets default values
for missing or invalid parameters.
:param bool pin_memory: Whether to use pinned memory for faster data
transfer to GPU.
:param int num_workers: The number of worker threads for data loading.
:param bool shuffle: Whether to shuffle the data during training.
:param int batch_size: The number of samples per batch to load.
"""
if pin_memory is not None:
check_consistency(pin_memory, bool)


@@ -1,4 +1,4 @@
"""Utils module."""
"""Module for utility functions."""
import types
from functools import reduce
@@ -12,14 +12,15 @@ def custom_warning_format(
message, category, filename, lineno, file=None, line=None
):
"""
Depewarning custom format.
Custom warning formatting function.
:param str message: The warning message.
:param class category: The warning category.
:param str filename: The filename where the warning was raised.
:param int lineno: The line number where the warning was raised.
:param str file: The file object where the warning was raised.
:param inr line: The line where the warning was raised.
:param Warning category: The warning category.
:param str filename: The filename where the warning is raised.
:param int lineno: The line number where the warning is raised.
:param str file: The file object where the warning is raised.
Default is ``None``.
:param int line: The line where the warning is raised.
:return: The formatted warning message.
:rtype: str
"""
@@ -27,20 +28,20 @@ def custom_warning_format(
def check_consistency(object_, object_instance, subclass=False):
"""Helper function to check object inheritance consistency.
Given a specific ``'object'`` we check if the object is
instance of a specific ``'object_instance'``, or in case
``'subclass=True'`` we check if the object is subclass
if the ``'object_instance'``.
"""
Check if an object maintains inheritance consistency.
:param (iterable or class object) object: The object to check the
inheritance
:param Object object_instance: The parent class from where the object
is expected to inherit
:param str object_name: The name of the object
:param bool subclass: Check if is a subclass and not instance
:raises ValueError: If the object does not inherit from the
specified class
This function checks whether a given object is an instance of a specified
class or, if ``subclass=True``, whether it is a subclass of the specified
class.
:param object_: The object to check.
:type object_: Iterable | Object
:param Object object_instance: The expected parent class.
:param bool subclass: If ``True``, checks whether ``object_`` is a subclass
of ``object_instance`` instead of an instance. Default is ``False``.
:raises ValueError: If ``object_`` does not inherit from ``object_instance``
as expected.
"""
if not isinstance(object_, (list, set, tuple)):
object_ = [object_]
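
A few call sketches for this helper, assuming it is imported from ``pina.utils`` (illustrative only):

    import torch
    from pina.utils import check_consistency

    check_consistency(0.8, float)                 # single instance: passes silently
    check_consistency([0.8, 0.1, 0.1], float)     # iterables are checked element-wise
    check_consistency(torch.nn.Linear, torch.nn.Module, subclass=True)  # subclass check
    # check_consistency("0.8", float)             # would raise ValueError
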
@@ -59,18 +60,28 @@ def check_consistency(object_, object_instance, subclass=False):
def labelize_forward(forward, input_variables, output_variables):
"""
Wrapper decorator to allow users to enable or disable the use of
LabelTensors during the forward pass.
Decorator to enable or disable the use of :class:`~pina.LabelTensor`
during the forward pass.
:param forward: The torch.nn.Module forward function.
:type forward: Callable
:param input_variables: The problem input variables.
:type input_variables: list[str] | tuple[str]
:param output_variables: The problem output variables.
:type output_variables: list[str] | tuple[str]
:param Callable forward: The forward function of a :class:`torch.nn.Module`.
:param list[str] input_variables: The names of the input variables of a
:class:`~pina.problem.AbstractProblem`.
:param list[str] output_variables: The names of the output variables of a
:class:`~pina.problem.AbstractProblem`.
:return: The decorated forward function.
:rtype: Callable
"""
def wrapper(x):
"""
Decorated forward function.
:param LabelTensor x: The labelized input of the forward pass of an
instance of :class:`torch.nn.Module`.
:return: The labelized output of the forward pass of an instance of
:class:`torch.nn.Module`.
:rtype: LabelTensor
"""
x = x.extract(input_variables)
output = forward(x)
# keep it like this, directly using LabelTensor(...) raises errors
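
A minimal sketch of the decorator in use, assuming a toy two-input, one-output network (names and shapes are illustrative):

    import torch
    from pina import LabelTensor
    from pina.utils import labelize_forward

    net = torch.nn.Linear(2, 1)
    forward = labelize_forward(
        net.forward, input_variables=["x", "y"], output_variables=["u"]
    )
    pts = LabelTensor(torch.rand(10, 2), labels=["x", "y"])
    out = forward(pts)   # LabelTensor of shape (10, 1) carrying the label "u"
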
@@ -82,15 +93,32 @@ def labelize_forward(forward, input_variables, output_variables):
return wrapper
def merge_tensors(tensors): # name to be changed
"""TODO"""
def merge_tensors(tensors):
"""
Merge a list of :class:`~pina.LabelTensor` instances into a single
:class:`~pina.LabelTensor` by iteratively applying the Cartesian
product.
:param list[LabelTensor] tensors: The list of tensors to merge.
:raises ValueError: If the list of tensors is empty.
:return: The merged tensor.
:rtype: LabelTensor
"""
if tensors:
return reduce(merge_two_tensors, tensors[1:], tensors[0])
raise ValueError("Expected at least one tensor")
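
A small sketch of the resulting Cartesian product, assuming two 1D :class:`~pina.LabelTensor` instances (shapes are illustrative):

    import torch
    from pina import LabelTensor
    from pina.utils import merge_tensors

    x = LabelTensor(torch.linspace(0, 1, 3).reshape(-1, 1), labels=["x"])
    t = LabelTensor(torch.linspace(0, 1, 4).reshape(-1, 1), labels=["t"])
    grid = merge_tensors([x, t])   # 3 * 4 = 12 rows, labels ["x", "t"]
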
def merge_two_tensors(tensor1, tensor2):
"""TODO"""
"""
Merge two :class:`~pina.LabelTensor` instances into a single
:class:`~pina.LabelTensor` by applying the Cartesian product.
:param LabelTensor tensor1: The first tensor to merge.
:param LabelTensor tensor2: The second tensor to merge.
:return: The merged tensor.
:rtype: LabelTensor
"""
n1 = tensor1.shape[0]
n2 = tensor2.shape[0]
@@ -102,12 +130,14 @@ def merge_two_tensors(tensor1, tensor2):
def torch_lhs(n, dim):
"""Latin Hypercube Sampling torch routine.
Sampling in range $[0, 1)^d$.
"""
The Latin Hypercube Sampling torch routine, sampling in :math:`[0, 1)^d`.
:param int n: number of samples
:param int dim: dimensions of latin hypercube
:return: samples
:param int n: The number of points to sample.
:param int dim: The number of dimensions of the sampling space.
:raises TypeError: If ``n`` or ``dim`` are not integers.
:raises ValueError: If ``dim`` is less than 1.
:return: The sampled points.
:rtype: torch.Tensor
"""
@@ -137,10 +167,10 @@ def torch_lhs(n, dim):
def is_function(f):
"""
Checks whether the given object `f` is a function or lambda.
Check if the given object is a function or a lambda.
:param object f: The object to be checked.
:return: `True` if `f` is a function, `False` otherwise.
:param Object f: The object to be checked.
:return: ``True`` if ``f`` is a function, ``False`` otherwise.
:rtype: bool
"""
return isinstance(f, (types.FunctionType, types.LambdaType))
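
A quick sketch of the behaviour (illustrative):

    import torch
    from pina.utils import is_function

    is_function(lambda x: x**2)   # True
    def square(x): return x**2
    is_function(square)           # True
    is_function(torch.sin)        # False: a torch builtin, not a Python function or lambda
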
@@ -148,11 +178,11 @@ def is_function(f):
def chebyshev_roots(n):
"""
Return the roots of *n* Chebyshev polynomials (between [-1, 1]).
Compute the roots of the Chebyshev polynomial of degree ``n``.
:param int n: number of roots
:return: roots
:rtype: torch.tensor
:param int n: The number of roots to return.
:return: The roots of the Chebyshev polynomial.
:rtype: torch.Tensor
"""
pi = torch.acos(torch.zeros(1)).item() * 2
k = torch.arange(n)
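
A short usage sketch, including an assumed rescaling of the roots from ``[-1, 1]`` to a generic interval (the rescaling is illustrative, not part of this module):

    from pina.utils import chebyshev_roots

    roots = chebyshev_roots(5)                     # 5 values in (-1, 1)
    a, b = 0.0, 2.0
    rescaled = 0.5 * (roots + 1.0) * (b - a) + a   # map to the interval [a, b]
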