fix utils and trainer doc
This commit is contained in:
131
pina/trainer.py
131
pina/trainer.py
@@ -1,4 +1,4 @@
|
||||
"""Trainer module."""
|
||||
"""Module for the Trainer."""
|
||||
|
||||
import sys
|
||||
import torch
|
||||
@@ -10,8 +10,11 @@ from .solver import SolverInterface, PINNInterface
|
||||
|
||||
class Trainer(lightning.pytorch.Trainer):
|
||||
"""
|
||||
PINA custom Trainer class which allows to customize standard Lightning
|
||||
Trainer class for PINNs training.
|
||||
PINA custom Trainer class to extend the standard Lightning functionality.
|
||||
|
||||
This class enables specific features or behaviors required by the PINA
|
||||
framework. It modifies the standard :class:`lightning.pytorch.Trainer` class
|
||||
to better support the training process in PINA.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -29,42 +32,35 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Initialize the Trainer class for by calling Lightning costructor and
|
||||
adding many other functionalities.
|
||||
Initialization of the :class:`Trainer` class.
|
||||
|
||||
:param solver: A pina:class:`SolverInterface` solver for the
|
||||
differential problem.
|
||||
:type solver: SolverInterface
|
||||
:param batch_size: How many samples per batch to load.
|
||||
If ``batch_size=None`` all
|
||||
samples are loaded and data are not batched, defaults to None.
|
||||
:type batch_size: int | None
|
||||
:param train_size: Percentage of elements in the train dataset.
|
||||
:type train_size: float
|
||||
:param test_size: Percentage of elements in the test dataset.
|
||||
:type test_size: float
|
||||
:param val_size: Percentage of elements in the val dataset.
|
||||
:type val_size: float
|
||||
:param compile: if True model is compiled before training,
|
||||
default False. For Windows users compilation is always disabled.
|
||||
:type compile: bool
|
||||
:param automatic_batching: if True automatic PyTorch batching is
|
||||
performed. Please avoid using automatic batching when batch_size is
|
||||
large, default False.
|
||||
:type automatic_batching: bool
|
||||
:param num_workers: Number of worker threads for data loading.
|
||||
Default 0 (serial loading).
|
||||
:type num_workers: int
|
||||
:param pin_memory: Whether to use pinned memory for faster data
|
||||
transfer to GPU. Default False.
|
||||
:type pin_memory: bool
|
||||
:param shuffle: Whether to shuffle the data for training. Default True.
|
||||
:type pin_memory: bool
|
||||
:param SolverInterface solver: A :class:`~pina.solver.SolverInterface`
|
||||
solver used to solve a :class:`~pina.problem.AbstractProblem`.
|
||||
:param int batch_size: The number of samples per batch to load.
|
||||
If ``None``, all samples are loaded and data is not batched.
|
||||
Default is ``None``.
|
||||
:param float train_size: The percentage of elements to include in the
|
||||
training dataset. Default is ``1.0``.
|
||||
:param float test_size: The percentage of elements to include in the
|
||||
test dataset. Default is ``0.0``.
|
||||
:param float val_size: The percentage of elements to include in the
|
||||
validation dataset. Default is ``0.0``.
|
||||
:param bool compile: If ``True``, the model is compiled before training.
|
||||
Default is ``False``. For Windows users, it is always disabled.
|
||||
:param bool automatic_batching: If ``True``, automatic PyTorch batching
|
||||
is performed. Avoid using automatic batching when ``batch_size`` is
|
||||
large. Default is ``False``.
|
||||
:param int num_workers: The number of worker threads for data loading.
|
||||
Default is ``0`` (serial loading).
|
||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||
transfer to GPU. Default is ``False``.
|
||||
:param bool shuffle: Whether to shuffle the data during training.
|
||||
Default is ``True``.
|
||||
|
||||
:Keyword Arguments:
|
||||
The additional keyword arguments specify the training setup
|
||||
and can be choosen from the `pytorch-lightning
|
||||
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
|
||||
Additional keyword arguments that specify the training setup.
|
||||
These can be selected from the pytorch-lightning Trainer API
|
||||
<https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>_.
|
||||
"""
|
||||
# check consistency for init types
|
||||
self._check_input_consistency(
|
||||
@@ -134,6 +130,10 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
}
|
||||
|
||||
def _move_to_device(self):
|
||||
"""
|
||||
Moves the ``unknown_parameters`` of an instance of
|
||||
:class:`~pina.problem.AbstractProblem` to the :class:`Trainer` device.
|
||||
"""
|
||||
device = self._accelerator_connector._parallel_devices[0]
|
||||
# move parameters to device
|
||||
pb = self.solver.problem
|
||||
@@ -155,9 +155,25 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
shuffle,
|
||||
):
|
||||
"""
|
||||
This method is used here because is resampling is needed
|
||||
during training, there is no need to define to touch the
|
||||
trainer dataloader, just call the method.
|
||||
This method is designed to handle the creation of a data module when
|
||||
resampling is needed during training. Instead of manually defining and
|
||||
modifying the trainer's dataloaders, this method is called to
|
||||
automatically configure the data module.
|
||||
|
||||
:param float train_size: The percentage of elements to include in the
|
||||
training dataset.
|
||||
:param float test_size: The percentage of elements to include in the
|
||||
test dataset.
|
||||
:param float val_size: The percentage of elements to include in the
|
||||
validation dataset.
|
||||
:param int batch_size: The number of samples per batch to load.
|
||||
:param bool automatic_batching: Whether to perform automatic batching
|
||||
with PyTorch.
|
||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||
transfer to GPU.
|
||||
:param int num_workers: The number of worker threads for data loading.
|
||||
:param bool shuffle: Whether to shuffle the data during training.
|
||||
:raises RuntimeError: If not all conditions are sampled.
|
||||
"""
|
||||
if not self.solver.problem.are_all_domains_discretised:
|
||||
error_message = "\n".join(
|
||||
@@ -188,25 +204,33 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
|
||||
def train(self, **kwargs):
|
||||
"""
|
||||
Train the solver method.
|
||||
Manage the training process of the solver.
|
||||
"""
|
||||
return super().fit(self.solver, datamodule=self.data_module, **kwargs)
|
||||
|
||||
def test(self, **kwargs):
|
||||
"""
|
||||
Test the solver method.
|
||||
Manage the test process of the solver.
|
||||
"""
|
||||
return super().test(self.solver, datamodule=self.data_module, **kwargs)
|
||||
|
||||
@property
|
||||
def solver(self):
|
||||
"""
|
||||
Returning trainer solver.
|
||||
Get the solver.
|
||||
|
||||
:return: The solver.
|
||||
:rtype: SolverInterface
|
||||
"""
|
||||
return self._solver
|
||||
|
||||
@solver.setter
|
||||
def solver(self, solver):
|
||||
"""
|
||||
Set the solver.
|
||||
|
||||
:param SolverInterface solver: The solver to set.
|
||||
"""
|
||||
self._solver = solver
|
||||
|
||||
@staticmethod
|
||||
@@ -214,7 +238,18 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
solver, train_size, test_size, val_size, automatic_batching, compile
|
||||
):
|
||||
"""
|
||||
Check the consistency of the input parameters."
|
||||
Verifies the consistency of the parameters for the solver configuration.
|
||||
|
||||
:param SolverInterface solver: The solver.
|
||||
:param float train_size: The percentage of elements to include in the
|
||||
training dataset.
|
||||
:param float test_size: The percentage of elements to include in the
|
||||
test dataset.
|
||||
:param float val_size: The percentage of elements to include in the
|
||||
validation dataset.
|
||||
:param bool automatic_batching: Whether to perform automatic batching
|
||||
with PyTorch.
|
||||
:param bool compile: If ``True``, the model is compiled before training.
|
||||
"""
|
||||
|
||||
check_consistency(solver, SolverInterface)
|
||||
@@ -231,8 +266,14 @@ class Trainer(lightning.pytorch.Trainer):
|
||||
pin_memory, num_workers, shuffle, batch_size
|
||||
):
|
||||
"""
|
||||
Check the consistency of the input parameters and set the default
|
||||
values.
|
||||
Checks the consistency of input parameters and sets default values
|
||||
for missing or invalid parameters.
|
||||
|
||||
:param bool pin_memory: Whether to use pinned memory for faster data
|
||||
transfer to GPU.
|
||||
:param int num_workers: The number of worker threads for data loading.
|
||||
:param bool shuffle: Whether to shuffle the data during training.
|
||||
:param int batch_size: The number of samples per batch to load.
|
||||
"""
|
||||
if pin_memory is not None:
|
||||
check_consistency(pin_memory, bool)
|
||||
|
||||
Reference in New Issue
Block a user