formatting

This commit is contained in:
Dario Coscia
2025-03-17 12:25:04 +01:00
committed by FilippoOlivo
parent 480140dd31
commit ea21754d53
8 changed files with 28 additions and 28 deletions

View File

@@ -20,18 +20,18 @@ class DummyDataloader:
def __init__(self, dataset): def __init__(self, dataset):
""" """
Prepare a dataloader object that returns the entire dataset in a single Prepare a dataloader object that returns the entire dataset in a single
batch. Depending on the number of GPUs, the dataset is managed batch. Depending on the number of GPUs, the dataset is managed
as follows: as follows:
- **Distributed Environment** (multiple GPUs): Divides dataset across - **Distributed Environment** (multiple GPUs): Divides dataset across
processes using the rank and world size. Fetches only portion of processes using the rank and world size. Fetches only portion of
data corresponding to the current process. data corresponding to the current process.
- **Non-Distributed Environment** (single GPU): Fetches the entire - **Non-Distributed Environment** (single GPU): Fetches the entire
dataset. dataset.
:param PinaDataset dataset: The dataset object to be processed. :param PinaDataset dataset: The dataset object to be processed.
.. note:: .. note::
This dataloader is used when the batch size is ``None``. This dataloader is used when the batch size is ``None``.
""" """
@@ -78,8 +78,8 @@ class Collator:
Initialize the object, setting the collate function based on whether Initialize the object, setting the collate function based on whether
automatic batching is enabled or not. automatic batching is enabled or not.
:param dict max_conditions_lengths: ``dict`` containing the maximum :param dict max_conditions_lengths: ``dict`` containing the maximum
number of data points to consider in a single batch for number of data points to consider in a single batch for
each condition. each condition.
:param bool automatic_batching: Whether to enable automatic batching. :param bool automatic_batching: Whether to enable automatic batching.
:param PinaDataset dataset: The dataset where the data is stored. :param PinaDataset dataset: The dataset where the data is stored.

View File

@@ -276,7 +276,7 @@ class PinaGraphDataset(PinaDataset):
:param data: List of items to collate in a single batch. :param data: List of items to collate in a single batch.
:type data: list[Data] | list[Graph] :type data: list[Data] | list[Graph]
:return: Batch object. :return: Batch object.
:rtype: :class:`~torch_geometric.data.Batch` :rtype: :class:`~torch_geometric.data.Batch`
| :class:`~pina.graph.LabelBatch` | :class:`~pina.graph.LabelBatch`
""" """

View File

@@ -389,7 +389,7 @@ class LabelTensor(torch.Tensor):
def requires_grad_(self, mode=True): def requires_grad_(self, mode=True):
""" """
Override the :meth:`~torch.Tensor.requires_grad_` method to handle Override the :meth:`~torch.Tensor.requires_grad_` method to handle
the labels in the new tensor. the labels in the new tensor.
For more details, see :meth:`~torch.Tensor.requires_grad_`. For more details, see :meth:`~torch.Tensor.requires_grad_`.

View File

@@ -325,14 +325,14 @@ class FNO(KernelNeuralOperator):
``projection_net`` maps the hidden representation to the output ``projection_net`` maps the hidden representation to the output
function. function.
:param x: The input tensor for performing the computation. Depending :param x: The input tensor for performing the computation. Depending
on the ``dimensions`` in the initialization, it expects a tensor on the ``dimensions`` in the initialization, it expects a tensor
with the following shapes: with the following shapes:
* 1D tensors: ``[batch, X, channels]`` * 1D tensors: ``[batch, X, channels]``
* 2D tensors: ``[batch, X, Y, channels]`` * 2D tensors: ``[batch, X, Y, channels]``
* 3D tensors: ``[batch, X, Y, Z, channels]`` * 3D tensors: ``[batch, X, Y, Z, channels]``
:type x: torch.Tensor | LabelTensor :type x: torch.Tensor | LabelTensor
:return: The output tensor. :return: The output tensor.
:rtype: torch.Tensor :rtype: torch.Tensor

View File

@@ -18,7 +18,7 @@ class TorchOptimizer(Optimizer):
:param torch.optim.Optimizer optimizer_class: A :param torch.optim.Optimizer optimizer_class: A
:class:`torch.optim.Optimizer` class. :class:`torch.optim.Optimizer` class.
:param dict kwargs: Additional parameters passed to ``optimizer_class``, :param dict kwargs: Additional parameters passed to ``optimizer_class``,
see more see more
`here <https://pytorch.org/docs/stable/optim.html#algorithms>`_. `here <https://pytorch.org/docs/stable/optim.html#algorithms>`_.
""" """
check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True) check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)

View File

@@ -24,7 +24,7 @@ class TorchScheduler(Scheduler):
:param torch.optim.LRScheduler scheduler_class: A :param torch.optim.LRScheduler scheduler_class: A
:class:`torch.optim.LRScheduler` class. :class:`torch.optim.LRScheduler` class.
:param dict kwargs: Additional parameters passed to ``scheduler_class``, :param dict kwargs: Additional parameters passed to ``scheduler_class``,
see more see more
`here <https://pytorch.org/docs/stable/optim.html#algorithms>_`. `here <https://pytorch.org/docs/stable/optim.html#algorithms>_`.
""" """
check_consistency(scheduler_class, LRScheduler, subclass=True) check_consistency(scheduler_class, LRScheduler, subclass=True)

View File

@@ -178,14 +178,14 @@ class AbstractProblem(metaclass=ABCMeta):
chebyshev sampling, ``chebyshev``; grid sampling ``grid``. chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
:param domains: The domains from which to sample. Default is ``all``. :param domains: The domains from which to sample. Default is ``all``.
:type domains: str | list[str] :type domains: str | list[str]
:param dict sample_rules: A dictionary defining custom sampling rules :param dict sample_rules: A dictionary defining custom sampling rules
for input variables. If provided, it must contain a dictionary for input variables. If provided, it must contain a dictionary
specifying the sampling rule for each variable, overriding the specifying the sampling rule for each variable, overriding the
``n`` and ``mode`` arguments. Each key must correspond to the ``n`` and ``mode`` arguments. Each key must correspond to the
input variables from input variables from
:meth:~pina.problem.AbstractProblem.input_variables, and its value :meth:~pina.problem.AbstractProblem.input_variables, and its value
should be another dictionary with should be another dictionary with
two keys: ``n`` (number of points to sample) and ``mode`` two keys: ``n`` (number of points to sample) and ``mode``
(sampling method). Defaults to None. (sampling method). Defaults to None.
:raises RuntimeError: If both ``n`` and ``sample_rules`` are specified. :raises RuntimeError: If both ``n`` and ``sample_rules`` are specified.
:raises RuntimeError: If neither ``n`` nor ``sample_rules`` are set. :raises RuntimeError: If neither ``n`` nor ``sample_rules`` are set.
@@ -214,8 +214,8 @@ class AbstractProblem(metaclass=ABCMeta):
implemented for :class:`~pina.domain.cartesian.CartesianDomain`. implemented for :class:`~pina.domain.cartesian.CartesianDomain`.
.. warning:: .. warning::
If custom discretisation is applied by setting ``sample_rules`` not If custom discretisation is applied by setting ``sample_rules`` not
to ``None``, then the discretised domain must be of class to ``None``, then the discretised domain must be of class
:class:`~pina.domain.cartesian.CartesianDomain` :class:`~pina.domain.cartesian.CartesianDomain`
""" """

View File

@@ -13,8 +13,8 @@ class Trainer(lightning.pytorch.Trainer):
PINA custom Trainer class to extend the standard Lightning functionality. PINA custom Trainer class to extend the standard Lightning functionality.
This class enables specific features or behaviors required by the PINA This class enables specific features or behaviors required by the PINA
framework. It modifies the standard framework. It modifies the standard
:class:`lightning.pytorch.Trainer <lightning.pytorch.trainer.trainer.Trainer>` :class:`lightning.pytorch.Trainer <lightning.pytorch.trainer.trainer.Trainer>`
class to better support the training process in PINA. class to better support the training process in PINA.
""" """
@@ -209,7 +209,7 @@ class Trainer(lightning.pytorch.Trainer):
Manage the training process of the solver. Manage the training process of the solver.
:param dict kwargs: Additional keyword arguments. See `pytorch-lightning :param dict kwargs: Additional keyword arguments. See `pytorch-lightning
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_ Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
for details. for details.
""" """
return super().fit(self.solver, datamodule=self.data_module, **kwargs) return super().fit(self.solver, datamodule=self.data_module, **kwargs)
@@ -219,7 +219,7 @@ class Trainer(lightning.pytorch.Trainer):
Manage the test process of the solver. Manage the test process of the solver.
:param dict kwargs: Additional keyword arguments. See `pytorch-lightning :param dict kwargs: Additional keyword arguments. See `pytorch-lightning
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_ Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
for details. for details.
""" """
return super().test(self.solver, datamodule=self.data_module, **kwargs) return super().test(self.solver, datamodule=self.data_module, **kwargs)