fix utils and trainer doc
--- a/pina/trainer.py
+++ b/pina/trainer.py
@@ -1,4 +1,4 @@
-"""Trainer module."""
+"""Module for the Trainer."""
 
 import sys
 import torch
@@ -10,8 +10,11 @@ from .solver import SolverInterface, PINNInterface
 
 class Trainer(lightning.pytorch.Trainer):
     """
-    PINA custom Trainer class which allows to customize standard Lightning
-    Trainer class for PINNs training.
+    PINA custom Trainer class to extend the standard Lightning functionality.
+
+    This class enables specific features or behaviors required by the PINA
+    framework. It modifies the standard :class:`lightning.pytorch.Trainer` class
+    to better support the training process in PINA.
     """
 
     def __init__(
@@ -29,42 +32,35 @@
         **kwargs,
     ):
         """
-        Initialize the Trainer class for by calling Lightning costructor and
-        adding many other functionalities.
+        Initialization of the :class:`Trainer` class.
 
-        :param solver: A pina:class:`SolverInterface` solver for the
-            differential problem.
-        :type solver: SolverInterface
-        :param batch_size: How many samples per batch to load.
-            If ``batch_size=None`` all
-            samples are loaded and data are not batched, defaults to None.
-        :type batch_size: int | None
-        :param train_size: Percentage of elements in the train dataset.
-        :type train_size: float
-        :param test_size: Percentage of elements in the test dataset.
-        :type test_size: float
-        :param val_size: Percentage of elements in the val dataset.
-        :type val_size: float
-        :param compile: if True model is compiled before training,
-            default False. For Windows users compilation is always disabled.
-        :type compile: bool
-        :param automatic_batching: if True automatic PyTorch batching is
-            performed. Please avoid using automatic batching when batch_size is
-            large, default False.
-        :type automatic_batching: bool
-        :param num_workers: Number of worker threads for data loading.
-            Default 0 (serial loading).
-        :type num_workers: int
-        :param pin_memory: Whether to use pinned memory for faster data
-            transfer to GPU. Default False.
-        :type pin_memory: bool
-        :param shuffle: Whether to shuffle the data for training. Default True.
-        :type pin_memory: bool
+        :param SolverInterface solver: A :class:`~pina.solver.SolverInterface`
+            solver used to solve a :class:`~pina.problem.AbstractProblem`.
+        :param int batch_size: The number of samples per batch to load.
+            If ``None``, all samples are loaded and data is not batched.
+            Default is ``None``.
+        :param float train_size: The percentage of elements to include in the
+            training dataset. Default is ``1.0``.
+        :param float test_size: The percentage of elements to include in the
+            test dataset. Default is ``0.0``.
+        :param float val_size: The percentage of elements to include in the
+            validation dataset. Default is ``0.0``.
+        :param bool compile: If ``True``, the model is compiled before training.
+            Default is ``False``. For Windows users, it is always disabled.
+        :param bool automatic_batching: If ``True``, automatic PyTorch batching
+            is performed. Avoid using automatic batching when ``batch_size`` is
+            large. Default is ``False``.
+        :param int num_workers: The number of worker threads for data loading.
+            Default is ``0`` (serial loading).
+        :param bool pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU. Default is ``False``.
+        :param bool shuffle: Whether to shuffle the data during training.
+            Default is ``True``.
 
         :Keyword Arguments:
-            The additional keyword arguments specify the training setup
-            and can be choosen from the `pytorch-lightning
-            Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
+            Additional keyword arguments that specify the training setup.
+            These can be selected from the `pytorch-lightning Trainer API
+            <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_.
         """
         # check consistency for init types
         self._check_input_consistency(
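Note: the ``train_size``, ``val_size`` and ``test_size`` percentages documented above combine as plain fractions of the sampled points. A minimal sketch of the arithmetic, with hypothetical values (PINA's internal splitting may round differently):

```python
# Hypothetical split of 200 collocation points; PINA's internal splitting
# logic is not reproduced here, only the fractional arithmetic.
n_points = 200
train_size, val_size, test_size = 0.7, 0.2, 0.1

n_train = int(n_points * train_size)  # 140
n_val = int(n_points * val_size)      # 40
n_test = n_points - n_train - n_val   # 20, remainder absorbs rounding

assert n_train + n_val + n_test == n_points
```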
@@ -134,6 +130,10 @@ class Trainer(lightning.pytorch.Trainer):
         }
 
     def _move_to_device(self):
+        """
+        Moves the ``unknown_parameters`` of an instance of
+        :class:`~pina.problem.AbstractProblem` to the :class:`Trainer` device.
+        """
         device = self._accelerator_connector._parallel_devices[0]
         # move parameters to device
         pb = self.solver.problem
@@ -155,9 +155,25 @@ class Trainer(lightning.pytorch.Trainer):
         shuffle,
     ):
         """
-        This method is used here because is resampling is needed
-        during training, there is no need to define to touch the
-        trainer dataloader, just call the method.
+        This method is designed to handle the creation of a data module when
+        resampling is needed during training. Instead of manually defining and
+        modifying the trainer's dataloaders, this method is called to
+        automatically configure the data module.
+
+        :param float train_size: The percentage of elements to include in the
+            training dataset.
+        :param float test_size: The percentage of elements to include in the
+            test dataset.
+        :param float val_size: The percentage of elements to include in the
+            validation dataset.
+        :param int batch_size: The number of samples per batch to load.
+        :param bool automatic_batching: Whether to perform automatic batching
+            with PyTorch.
+        :param bool pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU.
+        :param int num_workers: The number of worker threads for data loading.
+        :param bool shuffle: Whether to shuffle the data during training.
+        :raises RuntimeError: If not all conditions are sampled.
         """
         if not self.solver.problem.are_all_domains_discretised:
             error_message = "\n".join(
@@ -188,25 +204,33 @@ class Trainer(lightning.pytorch.Trainer):
 
     def train(self, **kwargs):
         """
-        Train the solver method.
+        Manage the training process of the solver.
         """
         return super().fit(self.solver, datamodule=self.data_module, **kwargs)
 
     def test(self, **kwargs):
         """
-        Test the solver method.
+        Manage the test process of the solver.
         """
         return super().test(self.solver, datamodule=self.data_module, **kwargs)
 
     @property
     def solver(self):
         """
-        Returning trainer solver.
+        Get the solver.
+
+        :return: The solver.
+        :rtype: SolverInterface
         """
         return self._solver
 
     @solver.setter
     def solver(self, solver):
+        """
+        Set the solver.
+
+        :param SolverInterface solver: The solver to set.
+        """
         self._solver = solver
 
     @staticmethod
@@ -214,7 +238,18 @@ class Trainer(lightning.pytorch.Trainer):
         solver, train_size, test_size, val_size, automatic_batching, compile
     ):
         """
-        Check the consistency of the input parameters."
+        Verifies the consistency of the parameters for the solver configuration.
+
+        :param SolverInterface solver: The solver.
+        :param float train_size: The percentage of elements to include in the
+            training dataset.
+        :param float test_size: The percentage of elements to include in the
+            test dataset.
+        :param float val_size: The percentage of elements to include in the
+            validation dataset.
+        :param bool automatic_batching: Whether to perform automatic batching
+            with PyTorch.
+        :param bool compile: If ``True``, the model is compiled before training.
         """
 
         check_consistency(solver, SolverInterface)
@@ -231,8 +266,14 @@ class Trainer(lightning.pytorch.Trainer):
         pin_memory, num_workers, shuffle, batch_size
     ):
         """
-        Check the consistency of the input parameters and set the default
-        values.
+        Checks the consistency of input parameters and sets default values
+        for missing or invalid parameters.
+
+        :param bool pin_memory: Whether to use pinned memory for faster data
+            transfer to GPU.
+        :param int num_workers: The number of worker threads for data loading.
+        :param bool shuffle: Whether to shuffle the data during training.
+        :param int batch_size: The number of samples per batch to load.
         """
         if pin_memory is not None:
             check_consistency(pin_memory, bool)
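Putting the documented parameters together, construction follows the pattern sketched below. Here ``problem`` is assumed to be an already-discretised :class:`~pina.problem.AbstractProblem` defined elsewhere, the linear ``model`` is a placeholder network, and ``PINN`` stands in for any ``SolverInterface`` implementation; this is a sketch of the documented API, not code from this commit:

```python
# Sketch of the documented Trainer API. `problem` is an assumed,
# already-discretised AbstractProblem instance defined elsewhere.
import torch

from pina import Trainer
from pina.solver import PINN

model = torch.nn.Linear(2, 1)
solver = PINN(problem=problem, model=model)  # any SolverInterface works

trainer = Trainer(
    solver,
    batch_size=None,  # default: load all samples, no batching
    train_size=0.8,   # fractional splits, as documented above
    val_size=0.1,
    test_size=0.1,
    compile=False,    # always disabled on Windows anyway
    max_epochs=100,   # forwarded to lightning.pytorch.Trainer
)
trainer.train()
trainer.test()
```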

--- a/pina/utils.py
+++ b/pina/utils.py
@@ -1,4 +1,4 @@
-"""Utils module."""
+"""Module for utility functions."""
 
 import types
 from functools import reduce
@@ -12,14 +12,15 @@ def custom_warning_format(
     message, category, filename, lineno, file=None, line=None
 ):
     """
-    Depewarning custom format.
+    Custom warning formatting function.
 
     :param str message: The warning message.
-    :param class category: The warning category.
-    :param str filename: The filename where the warning was raised.
-    :param int lineno: The line number where the warning was raised.
-    :param str file: The file object where the warning was raised.
-    :param inr line: The line where the warning was raised.
+    :param Warning category: The warning category.
+    :param str filename: The filename where the warning is raised.
+    :param int lineno: The line number where the warning is raised.
+    :param str file: The file object where the warning is raised.
+        Default is None.
+    :param int line: The line where the warning is raised.
     :return: The formatted warning message.
     :rtype: str
     """
@@ -27,20 +28,20 @@ def custom_warning_format(
 
 
 def check_consistency(object_, object_instance, subclass=False):
-    """Helper function to check object inheritance consistency.
-    Given a specific ``'object'`` we check if the object is
-    instance of a specific ``'object_instance'``, or in case
-    ``'subclass=True'`` we check if the object is subclass
-    if the ``'object_instance'``.
+    """
+    Check if an object maintains inheritance consistency.
 
-    :param (iterable or class object) object: The object to check the
-        inheritance
-    :param Object object_instance: The parent class from where the object
-        is expected to inherit
-    :param str object_name: The name of the object
-    :param bool subclass: Check if is a subclass and not instance
-    :raises ValueError: If the object does not inherit from the
-        specified class
+    This function checks whether a given object is an instance of a specified
+    class or, if ``subclass=True``, whether it is a subclass of the specified
+    class.
+
+    :param object: The object to check.
+    :type object: Iterable | Object
+    :param Object object_instance: The expected parent class.
+    :param bool subclass: If True, checks whether ``object_`` is a subclass
+        of ``object_instance`` instead of an instance. Default is ``False``.
+    :raises ValueError: If ``object_`` does not inherit from ``object_instance``
+        as expected.
     """
     if not isinstance(object_, (list, set, tuple)):
         object_ = [object_]
@@ -59,18 +60,28 @@ def check_consistency(object_, object_instance, subclass=False):
 
 def labelize_forward(forward, input_variables, output_variables):
     """
-    Wrapper decorator to allow users to enable or disable the use of
-    LabelTensors during the forward pass.
+    Decorator to enable or disable the use of :class:`~pina.LabelTensor`
+    during the forward pass.
 
-    :param forward: The torch.nn.Module forward function.
-    :type forward: Callable
-    :param input_variables: The problem input variables.
-    :type input_variables: list[str] | tuple[str]
-    :param output_variables: The problem output variables.
-    :type output_variables: list[str] | tuple[str]
+    :param Callable forward: The forward function of a :class:`torch.nn.Module`.
+    :param list[str] input_variables: The names of the input variables of a
+        :class:`~pina.problem.AbstractProblem`.
+    :param list[str] output_variables: The names of the output variables of a
+        :class:`~pina.problem.AbstractProblem`.
+    :return: The decorated forward function.
+    :rtype: Callable
     """
 
     def wrapper(x):
+        """
+        Decorated forward function.
+
+        :param LabelTensor x: The labelized input of the forward pass of an
+            instance of :class:`torch.nn.Module`.
+        :return: The labelized output of the forward pass of an instance of
+            :class:`torch.nn.Module`.
+        :rtype: LabelTensor
+        """
         x = x.extract(input_variables)
         output = forward(x)
         # keep it like this, directly using LabelTensor(...) raises errors
@@ -82,15 +93,32 @@ def labelize_forward(forward, input_variables, output_variables):
     return wrapper
 
 
-def merge_tensors(tensors): # name to be changed
-    """TODO"""
+def merge_tensors(tensors):
+    """
+    Merge a list of :class:`~pina.LabelTensor` instances into a single
+    :class:`~pina.LabelTensor` tensor, by applying iteratively the cartesian
+    product.
+
+    :param list[LabelTensor] tensors: The list of tensors to merge.
+    :raises ValueError: If the list of tensors is empty.
+    :return: The merged tensor.
+    :rtype: LabelTensor
+    """
     if tensors:
         return reduce(merge_two_tensors, tensors[1:], tensors[0])
     raise ValueError("Expected at least one tensor")
 
 
 def merge_two_tensors(tensor1, tensor2):
-    """TODO"""
+    """
+    Merge two :class:`~pina.LabelTensor` instances into a single
+    :class:`~pina.LabelTensor` tensor, by applying the cartesian product.
+
+    :param LabelTensor tensor1: The first tensor to merge.
+    :param LabelTensor tensor2: The second tensor to merge.
+    :return: The merged tensor.
+    :rtype: LabelTensor
+    """
     n1 = tensor1.shape[0]
     n2 = tensor2.shape[0]
 
@@ -102,12 +130,14 @@ def merge_two_tensors(tensor1, tensor2):
 
 
 def torch_lhs(n, dim):
-    """Latin Hypercube Sampling torch routine.
-    Sampling in range $[0, 1)^d$.
+    """
+    The Latin Hypercube Sampling torch routine, sampling in :math:`[0, 1)^d`.
 
-    :param int n: number of samples
-    :param int dim: dimensions of latin hypercube
-    :return: samples
+    :param int n: The number of points to sample.
+    :param int dim: The number of dimensions of the sampling space.
+    :raises TypeError: If ``n`` or ``dim`` are not integers.
+    :raises ValueError: If ``dim`` is less than 1.
+    :return: The sampled points.
     :rtype: torch.tensor
     """
 
@@ -137,10 +167,10 @@ def torch_lhs(n, dim):
 
 def is_function(f):
     """
-    Checks whether the given object `f` is a function or lambda.
+    Check if the given object is a function or a lambda.
 
-    :param object f: The object to be checked.
-    :return: `True` if `f` is a function, `False` otherwise.
+    :param Object f: The object to be checked.
+    :return: ``True`` if ``f`` is a function, ``False`` otherwise.
     :rtype: bool
     """
     return isinstance(f, (types.FunctionType, types.LambdaType))
@@ -148,11 +178,11 @@ def is_function(f):
 
 def chebyshev_roots(n):
     """
-    Return the roots of *n* Chebyshev polynomials (between [-1, 1]).
+    Compute the roots of the Chebyshev polynomial of degree ``n``.
 
-    :param int n: number of roots
-    :return: roots
-    :rtype: torch.tensor
+    :param int n: The number of roots to return.
+    :return: The roots of the Chebyshev polynomials.
+    :rtype: torch.Tensor
     """
     pi = torch.acos(torch.zeros(1)).item() * 2
     k = torch.arange(n)
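The docstrings above describe contracts that can be exercised directly. A sketch of ``check_consistency``, based solely on the documented behavior:

```python
import torch

from pina.utils import check_consistency

# A single object or an iterable of objects can be checked at once.
check_consistency(0.1, float)             # passes silently
check_consistency([0.1, 0.2], float)      # every element is checked
check_consistency(torch.nn.ReLU, torch.nn.Module, subclass=True)

try:
    check_consistency("not a float", float)
except ValueError as err:
    print(err)  # the failed inheritance check is reported
```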
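Likewise, the cartesian-product semantics of ``merge_tensors`` show up in the output shapes; this sketch assumes the usual :class:`~pina.LabelTensor` constructor:

```python
import torch

from pina import LabelTensor
from pina.utils import merge_tensors

# Two 1-D discretisations with distinct labels: 3 points in "x", 2 in "y".
x = LabelTensor(torch.rand(3, 1), labels=["x"])
y = LabelTensor(torch.rand(2, 1), labels=["y"])

merged = merge_tensors([x, y])
print(merged.shape)   # torch.Size([6, 2]): the 3 x 2 cartesian product
print(merged.labels)  # ['x', 'y']
```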
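Finally, a usage sketch for the two sampling helpers; the value ranges follow from the docstrings (Chebyshev nodes lie in [-1, 1], LHS samples in [0, 1) per dimension):

```python
import torch

from pina.utils import chebyshev_roots, torch_lhs

# Chebyshev nodes: T_n vanishes at x_k = cos((2k + 1) * pi / (2n)).
roots = chebyshev_roots(5)
assert ((roots >= -1) & (roots <= 1)).all()

# Latin Hypercube Sampling: 10 points in the unit square [0, 1)^2.
samples = torch_lhs(10, 2)
print(samples.shape)  # torch.Size([10, 2])
assert ((samples >= 0) & (samples < 1)).all()
```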