fix utils and trainer doc
This commit is contained in:
116
pina/utils.py
116
pina/utils.py
@@ -1,4 +1,4 @@
|
||||
"""Utils module."""
|
||||
"""Module for utility functions."""
|
||||
|
||||
import types
|
||||
from functools import reduce
|
||||
@@ -12,14 +12,15 @@ def custom_warning_format(
|
||||
message, category, filename, lineno, file=None, line=None
|
||||
):
|
||||
"""
|
||||
Depewarning custom format.
|
||||
Custom warning formatting function.
|
||||
|
||||
:param str message: The warning message.
|
||||
:param class category: The warning category.
|
||||
:param str filename: The filename where the warning was raised.
|
||||
:param int lineno: The line number where the warning was raised.
|
||||
:param str file: The file object where the warning was raised.
|
||||
:param inr line: The line where the warning was raised.
|
||||
:param Warning category: The warning category.
|
||||
:param str filename: The filename where the warning is raised.
|
||||
:param int lineno: The line number where the warning is raised.
|
||||
:param str file: The file object where the warning is raised.
|
||||
Default is None.
|
||||
:param int line: The line where the warning is raised.
|
||||
:return: The formatted warning message.
|
||||
:rtype: str
|
||||
"""
|
||||
@@ -27,20 +28,20 @@ def custom_warning_format(
|
||||
|
||||
|
||||
def check_consistency(object_, object_instance, subclass=False):
|
||||
"""Helper function to check object inheritance consistency.
|
||||
Given a specific ``'object'`` we check if the object is
|
||||
instance of a specific ``'object_instance'``, or in case
|
||||
``'subclass=True'`` we check if the object is subclass
|
||||
if the ``'object_instance'``.
|
||||
"""
|
||||
Check if an object maintains inheritance consistency.
|
||||
|
||||
:param (iterable or class object) object: The object to check the
|
||||
inheritance
|
||||
:param Object object_instance: The parent class from where the object
|
||||
is expected to inherit
|
||||
:param str object_name: The name of the object
|
||||
:param bool subclass: Check if is a subclass and not instance
|
||||
:raises ValueError: If the object does not inherit from the
|
||||
specified class
|
||||
This function checks whether a given object is an instance of a specified
|
||||
class or, if ``subclass=True``, whether it is a subclass of the specified
|
||||
class.
|
||||
|
||||
:param object: The object to check.
|
||||
:type object: Iterable | Object
|
||||
:param Object object_instance: The expected parent class.
|
||||
:param bool subclass: If True, checks whether ``object_`` is a subclass
|
||||
of ``object_instance`` instead of an instance. Default is ``False``.
|
||||
:raises ValueError: If ``object_`` does not inherit from ``object_instance``
|
||||
as expected.
|
||||
"""
|
||||
if not isinstance(object_, (list, set, tuple)):
|
||||
object_ = [object_]
|
||||
@@ -59,18 +60,28 @@ def check_consistency(object_, object_instance, subclass=False):
|
||||
|
||||
def labelize_forward(forward, input_variables, output_variables):
|
||||
"""
|
||||
Wrapper decorator to allow users to enable or disable the use of
|
||||
LabelTensors during the forward pass.
|
||||
Decorator to enable or disable the use of :class:`~pina.LabelTensor`
|
||||
during the forward pass.
|
||||
|
||||
:param forward: The torch.nn.Module forward function.
|
||||
:type forward: Callable
|
||||
:param input_variables: The problem input variables.
|
||||
:type input_variables: list[str] | tuple[str]
|
||||
:param output_variables: The problem output variables.
|
||||
:type output_variables: list[str] | tuple[str]
|
||||
:param Callable forward: The forward function of a :class:`torch.nn.Module`.
|
||||
:param list[str] input_variables: The names of the input variables of a
|
||||
:class:`~pina.problem.AbstractProblem`.
|
||||
:param list[str] output_variables: The names of the output variables of a
|
||||
:class:`~pina.problem.AbstractProblem`.
|
||||
:return: The decorated forward function.
|
||||
:rtype: Callable
|
||||
"""
|
||||
|
||||
def wrapper(x):
|
||||
"""
|
||||
Decorated forward function.
|
||||
|
||||
:param LabelTensor x: The labelized input of the forward pass of an
|
||||
instance of :class:`torch.nn.Module`.
|
||||
:return: The labelized output of the forward pass of an instance of
|
||||
:class:`torch.nn.Module`.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
x = x.extract(input_variables)
|
||||
output = forward(x)
|
||||
# keep it like this, directly using LabelTensor(...) raises errors
|
||||
@@ -82,15 +93,32 @@ def labelize_forward(forward, input_variables, output_variables):
|
||||
return wrapper
|
||||
|
||||
|
||||
def merge_tensors(tensors): # name to be changed
|
||||
"""TODO"""
|
||||
def merge_tensors(tensors):
|
||||
"""
|
||||
Merge a list of :class:`~pina.LabelTensor` instances into a single
|
||||
:class:`~pina.LabelTensor` tensor, by applying iteratively the cartesian
|
||||
product.
|
||||
|
||||
:param list[LabelTensor] tensors: The list of tensors to merge.
|
||||
:raises ValueError: If the list of tensors is empty.
|
||||
:return: The merged tensor.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
if tensors:
|
||||
return reduce(merge_two_tensors, tensors[1:], tensors[0])
|
||||
raise ValueError("Expected at least one tensor")
|
||||
|
||||
|
||||
def merge_two_tensors(tensor1, tensor2):
|
||||
"""TODO"""
|
||||
"""
|
||||
Merge two :class:`~pina.LabelTensor` instances into a single
|
||||
:class:`~pina.LabelTensor` tensor, by applying the cartesian product.
|
||||
|
||||
:param LabelTensor tensor1: The first tensor to merge.
|
||||
:param LabelTensor tensor2: The second tensor to merge.
|
||||
:return: The merged tensor.
|
||||
:rtype: LabelTensor
|
||||
"""
|
||||
n1 = tensor1.shape[0]
|
||||
n2 = tensor2.shape[0]
|
||||
|
||||
@@ -102,12 +130,14 @@ def merge_two_tensors(tensor1, tensor2):
|
||||
|
||||
|
||||
def torch_lhs(n, dim):
|
||||
"""Latin Hypercube Sampling torch routine.
|
||||
Sampling in range $[0, 1)^d$.
|
||||
"""
|
||||
The Latin Hypercube Sampling torch routine, sampling in :math:`[0, 1)`$.
|
||||
|
||||
:param int n: number of samples
|
||||
:param int dim: dimensions of latin hypercube
|
||||
:return: samples
|
||||
:param int n: The number of points to sample.
|
||||
:param int dim: The number of dimensions of the sampling space.
|
||||
:raises TypeError: If `n` or `dim` are not integers.
|
||||
:raises ValueError: If `dim` is less than 1.
|
||||
:return: The sampled points.
|
||||
:rtype: torch.tensor
|
||||
"""
|
||||
|
||||
@@ -137,10 +167,10 @@ def torch_lhs(n, dim):
|
||||
|
||||
def is_function(f):
|
||||
"""
|
||||
Checks whether the given object `f` is a function or lambda.
|
||||
Check if the given object is a function or a lambda.
|
||||
|
||||
:param object f: The object to be checked.
|
||||
:return: `True` if `f` is a function, `False` otherwise.
|
||||
:param Object f: The object to be checked.
|
||||
:return: ``True`` if ``f`` is a function, ``False`` otherwise.
|
||||
:rtype: bool
|
||||
"""
|
||||
return isinstance(f, (types.FunctionType, types.LambdaType))
|
||||
@@ -148,11 +178,11 @@ def is_function(f):
|
||||
|
||||
def chebyshev_roots(n):
|
||||
"""
|
||||
Return the roots of *n* Chebyshev polynomials (between [-1, 1]).
|
||||
Compute the roots of the Chebyshev polynomial of degree ``n``.
|
||||
|
||||
:param int n: number of roots
|
||||
:return: roots
|
||||
:rtype: torch.tensor
|
||||
:param int n: The number of roots to return.
|
||||
:return: The roots of the Chebyshev polynomials.
|
||||
:rtype: torch.Tensor
|
||||
"""
|
||||
pi = torch.acos(torch.zeros(1)).item() * 2
|
||||
k = torch.arange(n)
|
||||
|
||||
Reference in New Issue
Block a user