fix utils and trainer doc

giovanni
2025-03-13 10:47:30 +01:00
committed by Nicola Demo
parent c1be748372
commit df5f6ec2e1
2 changed files with 159 additions and 88 deletions


@@ -1,4 +1,4 @@
"""Trainer module."""
"""Module for the Trainer."""
import sys
import torch
@@ -10,8 +10,11 @@ from .solver import SolverInterface, PINNInterface
class Trainer(lightning.pytorch.Trainer):
"""
-PINA custom Trainer class which allows to customize standard Lightning
-Trainer class for PINNs training.
+PINA custom Trainer class to extend the standard Lightning functionality.
+This class enables specific features or behaviors required by the PINA
+framework. It modifies the standard :class:`lightning.pytorch.Trainer` class
+to better support the training process in PINA.
"""
def __init__(
@@ -29,42 +32,35 @@ class Trainer(lightning.pytorch.Trainer):
**kwargs,
):
"""
-Initialize the Trainer class for by calling Lightning costructor and
-adding many other functionalities.
+Initialization of the :class:`Trainer` class.
-:param solver: A pina:class:`SolverInterface` solver for the
-differential problem.
-:type solver: SolverInterface
-:param batch_size: How many samples per batch to load.
-If ``batch_size=None`` all
-samples are loaded and data are not batched, defaults to None.
-:type batch_size: int | None
-:param train_size: Percentage of elements in the train dataset.
-:type train_size: float
-:param test_size: Percentage of elements in the test dataset.
-:type test_size: float
-:param val_size: Percentage of elements in the val dataset.
-:type val_size: float
-:param compile: if True model is compiled before training,
-default False. For Windows users compilation is always disabled.
-:type compile: bool
-:param automatic_batching: if True automatic PyTorch batching is
-performed. Please avoid using automatic batching when batch_size is
-large, default False.
-:type automatic_batching: bool
-:param num_workers: Number of worker threads for data loading.
-Default 0 (serial loading).
-:type num_workers: int
-:param pin_memory: Whether to use pinned memory for faster data
-transfer to GPU. Default False.
-:type pin_memory: bool
-:param shuffle: Whether to shuffle the data for training. Default True.
-:type pin_memory: bool
+:param SolverInterface solver: A :class:`~pina.solver.SolverInterface`
+solver used to solve a :class:`~pina.problem.AbstractProblem`.
+:param int batch_size: The number of samples per batch to load.
+If ``None``, all samples are loaded and data is not batched.
+Default is ``None``.
+:param float train_size: The percentage of elements to include in the
+training dataset. Default is ``1.0``.
+:param float test_size: The percentage of elements to include in the
+test dataset. Default is ``0.0``.
+:param float val_size: The percentage of elements to include in the
+validation dataset. Default is ``0.0``.
+:param bool compile: If ``True``, the model is compiled before training.
+Default is ``False``. For Windows users, it is always disabled.
+:param bool automatic_batching: If ``True``, automatic PyTorch batching
+is performed. Avoid using automatic batching when ``batch_size`` is
+large. Default is ``False``.
+:param int num_workers: The number of worker threads for data loading.
+Default is ``0`` (serial loading).
+:param bool pin_memory: Whether to use pinned memory for faster data
+transfer to GPU. Default is ``False``.
+:param bool shuffle: Whether to shuffle the data during training.
+Default is ``True``.
:Keyword Arguments:
-The additional keyword arguments specify the training setup
-and can be choosen from the `pytorch-lightning
-Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
+Additional keyword arguments that specify the training setup.
+These can be selected from the `pytorch-lightning Trainer API
+<https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_.
"""
# check consistency for init types
self._check_input_consistency(
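The rewritten parameter list above documents the whole constructor surface. A minimal usage sketch; ``problem`` and ``model`` are assumed to be defined elsewhere, and while the import paths follow the ``from .solver import ...`` line visible in this diff, the ``PINN`` solver name is an assumption that may differ between PINA versions:

from pina import Trainer
from pina.solver import PINN

# `problem` is assumed to be an already-defined AbstractProblem subclass,
# `model` any compatible torch.nn.Module (both hypothetical here).
solver = PINN(problem=problem, model=model)

trainer = Trainer(
    solver,
    batch_size=None,  # load all samples at once, no batching
    train_size=0.8,   # 80% of the sampled points for training
    val_size=0.1,     # 10% for validation
    test_size=0.1,    # 10% for testing
    compile=False,    # skip model compilation (always off on Windows)
    max_epochs=1000,  # extra kwargs are forwarded to lightning's Trainer
)
trainer.train()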
@@ -134,6 +130,10 @@ class Trainer(lightning.pytorch.Trainer):
}
def _move_to_device(self):
"""
Moves the ``unknown_parameters`` of an instance of
:class:`~pina.problem.AbstractProblem` to the :class:`Trainer` device.
"""
device = self._accelerator_connector._parallel_devices[0]
# move parameters to device
pb = self.solver.problem
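For orientation, a standalone sketch of the behaviour the new docstring describes, assuming ``unknown_parameters`` is a dict of ``torch.nn.Parameter``; the helper name and loop body are illustrative, not the verbatim method:

import torch

def move_unknown_parameters(problem, device):
    # Relocate each learnable unknown parameter onto the trainer device,
    # re-wrapping so it remains a torch.nn.Parameter after the move.
    for name, param in problem.unknown_parameters.items():
        problem.unknown_parameters[name] = torch.nn.Parameter(
            param.data.to(device), requires_grad=True
        )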
@@ -155,9 +155,25 @@ class Trainer(lightning.pytorch.Trainer):
shuffle,
):
"""
-This method is used here because is resampling is needed
-during training, there is no need to define to touch the
-trainer dataloader, just call the method.
+This method is designed to handle the creation of a data module when
+resampling is needed during training. Instead of manually defining and
+modifying the trainer's dataloaders, this method is called to
+automatically configure the data module.
+:param float train_size: The percentage of elements to include in the
+training dataset.
+:param float test_size: The percentage of elements to include in the
+test dataset.
+:param float val_size: The percentage of elements to include in the
+validation dataset.
+:param int batch_size: The number of samples per batch to load.
+:param bool automatic_batching: Whether to perform automatic batching
+with PyTorch.
+:param bool pin_memory: Whether to use pinned memory for faster data
+transfer to GPU.
+:param int num_workers: The number of worker threads for data loading.
+:param bool shuffle: Whether to shuffle the data during training.
+:raises RuntimeError: If not all conditions are sampled.
"""
if not self.solver.problem.are_all_domains_discretised:
error_message = "\n".join(
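The guard shown above backs the new ``:raises RuntimeError:`` entry: the data module cannot be created while any domain is still unsampled. A hedged sketch of the call a user makes beforehand; the exact ``discretise_domain`` signature is assumed from typical PINA usage and may vary by version:

# Sample every domain before building the Trainer, so that
# _create_datamodule finds data for all conditions.
problem.discretise_domain(n=100, mode="random", domains="all")

trainer = Trainer(solver, train_size=0.7, val_size=0.1, test_size=0.2)
trainer.train()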
@@ -188,25 +204,33 @@ class Trainer(lightning.pytorch.Trainer):
def train(self, **kwargs):
"""
-Train the solver method.
+Manage the training process of the solver.
"""
return super().fit(self.solver, datamodule=self.data_module, **kwargs)
def test(self, **kwargs):
"""
-Test the solver method.
+Manage the test process of the solver.
"""
return super().test(self.solver, datamodule=self.data_module, **kwargs)
@property
def solver(self):
"""
-Returning trainer solver.
+Get the solver.
+:return: The solver.
+:rtype: SolverInterface
"""
return self._solver
@solver.setter
def solver(self, solver):
"""
Set the solver.
:param SolverInterface solver: The solver to set.
"""
self._solver = solver
@staticmethod
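Taken together, the ``train``/``test`` wrappers and the ``solver`` property give the short workflow below; a sketch assuming a ``trainer`` built as in the earlier example:

trainer.train()           # delegates to lightning's fit()
results = trainer.test()  # delegates to lightning's test()
solver = trainer.solver   # property returning the SolverInterface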
@@ -214,7 +238,18 @@ class Trainer(lightning.pytorch.Trainer):
solver, train_size, test_size, val_size, automatic_batching, compile
):
"""
Check the consistency of the input parameters."
Verifies the consistency of the parameters for the solver configuration.
:param SolverInterface solver: The solver.
:param float train_size: The percentage of elements to include in the
training dataset.
:param float test_size: The percentage of elements to include in the
test dataset.
:param float val_size: The percentage of elements to include in the
validation dataset.
:param bool automatic_batching: Whether to perform automatic batching
with PyTorch.
:param bool compile: If ``True``, the model is compiled before training.
"""
check_consistency(solver, SolverInterface)
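Only the first check survives in this hunk; the rest of ``_check_input_consistency`` is elided. A sketch of what the remaining checks plausibly look like, assuming ``check_consistency`` raises when a value is not an instance of the expected type:

check_consistency(solver, SolverInterface)
check_consistency(train_size, float)   # assumed: remaining checks mirror
check_consistency(test_size, float)    # the documented parameter types
check_consistency(val_size, float)
check_consistency(automatic_batching, bool)
check_consistency(compile, bool)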
@@ -231,8 +266,14 @@ class Trainer(lightning.pytorch.Trainer):
pin_memory, num_workers, shuffle, batch_size
):
"""
-Check the consistency of the input parameters and set the default
-values.
+Checks the consistency of input parameters and sets default values
+for missing or invalid parameters.
+:param bool pin_memory: Whether to use pinned memory for faster data
+transfer to GPU.
+:param int num_workers: The number of worker threads for data loading.
+:param bool shuffle: Whether to shuffle the data during training.
+:param int batch_size: The number of samples per batch to load.
"""
if pin_memory is not None:
check_consistency(pin_memory, bool)
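The diff cuts off here. For completeness, a hedged sketch of the validate-or-default pattern the new docstring describes, with defaults mirroring those documented in ``__init__``; the actual body beyond the first check is not shown in this hunk:

# Assumed continuation: each argument is type-checked when given,
# otherwise replaced by its documented default.
if pin_memory is not None:
    check_consistency(pin_memory, bool)
else:
    pin_memory = False  # documented default
if num_workers is not None:
    check_consistency(num_workers, int)
else:
    num_workers = 0     # serial loading
if shuffle is not None:
    check_consistency(shuffle, bool)
else:
    shuffle = True
if batch_size is not None:
    check_consistency(batch_size, int)  # None means "no batching"
return pin_memory, num_workers, shuffle, batch_size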