formatting

This commit is contained in:
Dario Coscia
2025-03-17 12:25:04 +01:00
committed by FilippoOlivo
parent 480140dd31
commit ea21754d53
8 changed files with 28 additions and 28 deletions

View File

@@ -20,18 +20,18 @@ class DummyDataloader:
def __init__(self, dataset):
"""
Prepare a dataloader object that returns the entire dataset in a single
batch. Depending on the number of GPUs, the dataset is managed
batch. Depending on the number of GPUs, the dataset is managed
as follows:
- **Distributed Environment** (multiple GPUs): Divides dataset across
processes using the rank and world size. Fetches only portion of
- **Distributed Environment** (multiple GPUs): Divides dataset across
processes using the rank and world size. Fetches only portion of
data corresponding to the current process.
- **Non-Distributed Environment** (single GPU): Fetches the entire
- **Non-Distributed Environment** (single GPU): Fetches the entire
dataset.
:param PinaDataset dataset: The dataset object to be processed.
.. note::
.. note::
This dataloader is used when the batch size is ``None``.
"""
@@ -78,8 +78,8 @@ class Collator:
Initialize the object, setting the collate function based on whether
automatic batching is enabled or not.
:param dict max_conditions_lengths: ``dict`` containing the maximum
number of data points to consider in a single batch for
:param dict max_conditions_lengths: ``dict`` containing the maximum
number of data points to consider in a single batch for
each condition.
:param bool automatic_batching: Whether to enable automatic batching.
:param PinaDataset dataset: The dataset where the data is stored.

View File

@@ -276,7 +276,7 @@ class PinaGraphDataset(PinaDataset):
:param data: List of items to collate in a single batch.
:type data: list[Data] | list[Graph]
:return: Batch object.
:rtype: :class:`~torch_geometric.data.Batch`
:rtype: :class:`~torch_geometric.data.Batch`
| :class:`~pina.graph.LabelBatch`
"""

View File

@@ -389,7 +389,7 @@ class LabelTensor(torch.Tensor):
def requires_grad_(self, mode=True):
"""
Override the :meth:`~torch.Tensor.requires_grad_` method to handle
Override the :meth:`~torch.Tensor.requires_grad_` method to handle
the labels in the new tensor.
For more details, see :meth:`~torch.Tensor.requires_grad_`.

View File

@@ -325,14 +325,14 @@ class FNO(KernelNeuralOperator):
``projection_net`` maps the hidden representation to the output
function.
:param x: The input tensor for performing the computation. Depending
:param x: The input tensor for performing the computation. Depending
on the ``dimensions`` in the initialization, it expects a tensor
with the following shapes:
* 1D tensors: ``[batch, X, channels]``
* 2D tensors: ``[batch, X, Y, channels]``
* 3D tensors: ``[batch, X, Y, Z, channels]``
:type x: torch.Tensor | LabelTensor
:return: The output tensor.
:rtype: torch.Tensor

View File

@@ -18,7 +18,7 @@ class TorchOptimizer(Optimizer):
:param torch.optim.Optimizer optimizer_class: A
:class:`torch.optim.Optimizer` class.
:param dict kwargs: Additional parameters passed to ``optimizer_class``,
see more
see more
`here <https://pytorch.org/docs/stable/optim.html#algorithms>`_.
"""
check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)

View File

@@ -24,7 +24,7 @@ class TorchScheduler(Scheduler):
:param torch.optim.LRScheduler scheduler_class: A
:class:`torch.optim.LRScheduler` class.
:param dict kwargs: Additional parameters passed to ``scheduler_class``,
see more
see more
`here <https://pytorch.org/docs/stable/optim.html#algorithms>_`.
"""
check_consistency(scheduler_class, LRScheduler, subclass=True)

View File

@@ -178,14 +178,14 @@ class AbstractProblem(metaclass=ABCMeta):
chebyshev sampling, ``chebyshev``; grid sampling ``grid``.
:param domains: The domains from which to sample. Default is ``all``.
:type domains: str | list[str]
:param dict sample_rules: A dictionary defining custom sampling rules
for input variables. If provided, it must contain a dictionary
specifying the sampling rule for each variable, overriding the
``n`` and ``mode`` arguments. Each key must correspond to the
input variables from
:meth:~pina.problem.AbstractProblem.input_variables, and its value
should be another dictionary with
two keys: ``n`` (number of points to sample) and ``mode``
:param dict sample_rules: A dictionary defining custom sampling rules
for input variables. If provided, it must contain a dictionary
specifying the sampling rule for each variable, overriding the
``n`` and ``mode`` arguments. Each key must correspond to the
input variables from
:meth:~pina.problem.AbstractProblem.input_variables, and its value
should be another dictionary with
two keys: ``n`` (number of points to sample) and ``mode``
(sampling method). Defaults to None.
:raises RuntimeError: If both ``n`` and ``sample_rules`` are specified.
:raises RuntimeError: If neither ``n`` nor ``sample_rules`` are set.
@@ -214,8 +214,8 @@ class AbstractProblem(metaclass=ABCMeta):
implemented for :class:`~pina.domain.cartesian.CartesianDomain`.
.. warning::
If custom discretisation is applied by setting ``sample_rules`` not
to ``None``, then the discretised domain must be of class
If custom discretisation is applied by setting ``sample_rules`` not
to ``None``, then the discretised domain must be of class
:class:`~pina.domain.cartesian.CartesianDomain`
"""

View File

@@ -13,8 +13,8 @@ class Trainer(lightning.pytorch.Trainer):
PINA custom Trainer class to extend the standard Lightning functionality.
This class enables specific features or behaviors required by the PINA
framework. It modifies the standard
:class:`lightning.pytorch.Trainer <lightning.pytorch.trainer.trainer.Trainer>`
framework. It modifies the standard
:class:`lightning.pytorch.Trainer <lightning.pytorch.trainer.trainer.Trainer>`
class to better support the training process in PINA.
"""
@@ -209,7 +209,7 @@ class Trainer(lightning.pytorch.Trainer):
Manage the training process of the solver.
:param dict kwargs: Additional keyword arguments. See `pytorch-lightning
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
for details.
"""
return super().fit(self.solver, datamodule=self.data_module, **kwargs)
@@ -219,7 +219,7 @@ class Trainer(lightning.pytorch.Trainer):
Manage the test process of the solver.
:param dict kwargs: Additional keyword arguments. See `pytorch-lightning
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
for details.
"""
return super().test(self.solver, datamodule=self.data_module, **kwargs)