formatting

committed by Nicola Demo
parent 2526da36bf
commit 3c301acf18
@@ -20,18 +20,18 @@ class DummyDataloader:

def __init__(self, dataset):
    """
    Prepare a dataloader object that returns the entire dataset in a single
    batch. Depending on the number of GPUs, the dataset is managed as follows:

    - **Distributed Environment** (multiple GPUs): Divides the dataset across
      processes using the rank and world size. Fetches only the portion of
      data corresponding to the current process.
    - **Non-Distributed Environment** (single GPU): Fetches the entire
      dataset.

    :param PinaDataset dataset: The dataset object to be processed.

    .. note::
        This dataloader is used when the batch size is ``None``.
    """
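For intuition, here is a minimal sketch of such a single-batch dataloader; it is not PINA's implementation, and the rank/world-size slicing shown is only an assumption consistent with the description above.

import torch.distributed as dist


class SingleBatchDataloader:
    """Return the whole per-process dataset as one batch."""

    def __init__(self, dataset):
        if dist.is_available() and dist.is_initialized():
            # Distributed: keep only the slice belonging to this process.
            rank, world_size = dist.get_rank(), dist.get_world_size()
            indices = range(rank, len(dataset), world_size)
        else:
            # Single GPU / CPU: keep everything.
            indices = range(len(dataset))
        self.batch = [dataset[i] for i in indices]

    def __iter__(self):
        # Yield the single batch once per epoch.
        return iter([self.batch])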
@@ -78,8 +78,8 @@ class Collator:

    Initialize the object, setting the collate function based on whether
    automatic batching is enabled or not.

    :param dict max_conditions_lengths: ``dict`` containing the maximum
        number of data points to consider in a single batch for
        each condition.
    :param bool automatic_batching: Whether to enable automatic batching.
    :param PinaDataset dataset: The dataset where the data is stored.
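A hedged sketch of the dispatch pattern this docstring describes; the method and attribute names below are illustrative, not PINA's actual ones.

class Collator:
    """Choose a collate strategy once, at construction time."""

    def __init__(self, max_conditions_lengths, automatic_batching, dataset):
        self.max_conditions_lengths = max_conditions_lengths
        self.dataset = dataset
        # With automatic batching, items arrive already fetched and only need
        # assembling; otherwise the collator fetches data from the dataset.
        self._collate_fn = (
            self._collate_automatic if automatic_batching
            else self._collate_manual
        )

    def _trim(self, condition, items):
        # Keep at most the configured number of points for this condition.
        return items[: self.max_conditions_lengths[condition]]

    def _collate_automatic(self, batch):
        # Assume each item is a dict keyed by condition (illustrative only).
        return {c: self._trim(c, items) for c, items in batch[0].items()}

    def _collate_manual(self, indices):
        # Fetch the requested items directly from the dataset.
        return self.dataset[indices]

    def __call__(self, batch):
        return self._collate_fn(batch)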
@@ -276,7 +276,7 @@ class PinaGraphDataset(PinaDataset):

    :param data: List of items to collate in a single batch.
    :type data: list[Data] | list[Graph]
    :return: Batch object.
    :rtype: :class:`~torch_geometric.data.Batch` |
        :class:`~pina.graph.LabelBatch`
    """
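Collating a list of graphs into one batch is the standard PyTorch Geometric pattern; the snippet below uses torch_geometric's Batch.from_data_list to show the expected behaviour (PINA's LabelBatch presumably adds label handling on top of this).

import torch
from torch_geometric.data import Data, Batch

# Two small graphs with 3 and 2 nodes respectively.
g1 = Data(x=torch.rand(3, 4), edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))
g2 = Data(x=torch.rand(2, 4), edge_index=torch.tensor([[0, 1], [1, 0]]))

# A single Batch object whose `batch` vector maps each node to its graph.
batch = Batch.from_data_list([g1, g2])
print(batch.num_graphs)  # 2
print(batch.batch)       # tensor([0, 0, 0, 1, 1])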
@@ -389,7 +389,7 @@ class LabelTensor(torch.Tensor):

def requires_grad_(self, mode=True):
    """
    Override the :meth:`~torch.Tensor.requires_grad_` method to handle
    the labels in the new tensor.
    For more details, see :meth:`~torch.Tensor.requires_grad_`.
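The general pattern, shown on a simplified torch.Tensor subclass (a stand-in, not PINA's LabelTensor), is to call the parent method and re-attach the label metadata to the returned tensor:

import torch


class LabeledTensor(torch.Tensor):
    """Simplified tensor subclass carrying a `labels` attribute."""

    @staticmethod
    def __new__(cls, data, labels):
        tensor = torch.as_tensor(data).as_subclass(cls)
        tensor.labels = labels
        return tensor

    def requires_grad_(self, mode=True):
        # torch.Tensor.requires_grad_ returns the tensor itself; re-attaching
        # the labels keeps the metadata explicit on the returned object.
        out = super().requires_grad_(mode)
        out.labels = self.labels
        return out


t = LabeledTensor(torch.rand(4, 2), labels=["x", "y"])
t.requires_grad_()
print(t.requires_grad, t.labels)  # True ['x', 'y']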
@@ -325,14 +325,14 @@ class FNO(KernelNeuralOperator):

    ``projection_net`` maps the hidden representation to the output
    function.

    :param x: The input tensor for performing the computation. Depending
        on the ``dimensions`` in the initialization, it expects a tensor
        with the following shapes:

        * 1D tensors: ``[batch, X, channels]``
        * 2D tensors: ``[batch, X, Y, channels]``
        * 3D tensors: ``[batch, X, Y, Z, channels]``

    :type x: torch.Tensor | LabelTensor
    :return: The output tensor.
    :rtype: torch.Tensor
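The channel-last layouts listed above translate to tensors like the following; the batch size, grid resolutions, and channel counts are arbitrary placeholders.

import torch

x_1d = torch.rand(8, 64, 3)          # [batch, X, channels]
x_2d = torch.rand(8, 64, 64, 3)      # [batch, X, Y, channels]
x_3d = torch.rand(8, 32, 32, 32, 3)  # [batch, X, Y, Z, channels]

# With a 2D FNO, the forward pass maps x_2d to a tensor of shape
# [batch, X, Y, out_channels]; the output channel count depends on the
# projection_net configuration.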
@@ -18,7 +18,7 @@ class TorchOptimizer(Optimizer):

    :param torch.optim.Optimizer optimizer_class: A
        :class:`torch.optim.Optimizer` class.
    :param dict kwargs: Additional parameters passed to ``optimizer_class``,
        see more
        `here <https://pytorch.org/docs/stable/optim.html#algorithms>`_.
    """
    check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)
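A minimal sketch of the deferred-instantiation idea behind such a wrapper; the class and method names here are hypothetical, not PINA's API.

import torch


class LazyOptimizer:
    """Store the optimizer class and kwargs; build the optimizer later."""

    def __init__(self, optimizer_class, **kwargs):
        assert issubclass(optimizer_class, torch.optim.Optimizer)
        self.optimizer_class = optimizer_class
        self.kwargs = kwargs

    def build(self, parameters):
        # Instantiate the real optimizer once the model parameters exist.
        return self.optimizer_class(parameters, **self.kwargs)


model = torch.nn.Linear(4, 1)
adam = LazyOptimizer(torch.optim.Adam, lr=1e-3).build(model.parameters())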
@@ -24,7 +24,7 @@ class TorchScheduler(Scheduler):

    :param torch.optim.LRScheduler scheduler_class: A
        :class:`torch.optim.LRScheduler` class.
    :param dict kwargs: Additional parameters passed to ``scheduler_class``,
        see more
        `here <https://pytorch.org/docs/stable/optim.html#algorithms>`_.
    """
    check_consistency(scheduler_class, LRScheduler, subclass=True)
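For reference, the equivalent direct PyTorch calls for a wrapped scheduler class and its keyword arguments; StepLR and its settings are arbitrary examples.

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

for epoch in range(3):
    optimizer.step()   # normally preceded by a backward pass
    scheduler.step()   # decays the learning rate every `step_size` epochs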
@@ -178,14 +178,14 @@ class AbstractProblem(metaclass=ABCMeta):

    chebyshev sampling, ``chebyshev``; grid sampling, ``grid``.
    :param domains: The domains from which to sample. Default is ``all``.
    :type domains: str | list[str]
    :param dict sample_rules: A dictionary defining custom sampling rules
        for input variables. If provided, it must contain a dictionary
        specifying the sampling rule for each variable, overriding the
        ``n`` and ``mode`` arguments. Each key must correspond to one of the
        input variables from
        :meth:`~pina.problem.AbstractProblem.input_variables`, and its value
        should be another dictionary with two keys: ``n`` (the number of
        points to sample) and ``mode`` (the sampling method). Defaults to
        ``None``.
    :raises RuntimeError: If both ``n`` and ``sample_rules`` are specified.
    :raises RuntimeError: If neither ``n`` nor ``sample_rules`` are set.
@@ -214,8 +214,8 @@ class AbstractProblem(metaclass=ABCMeta):

    implemented for :class:`~pina.domain.cartesian.CartesianDomain`.

    .. warning::
        If custom discretisation is applied by setting ``sample_rules`` not
        to ``None``, then the discretised domain must be of class
        :class:`~pina.domain.cartesian.CartesianDomain`.
    """
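An illustration of the ``sample_rules`` structure described in this docstring; the variable names ``x`` and ``t`` and the call shown in the comment are placeholders, not taken from PINA's examples.

# Each key is an input variable; each value overrides `n` and `mode`
# for that variable.
sample_rules = {
    "x": {"n": 100, "mode": "grid"},
    "t": {"n": 50, "mode": "chebyshev"},
}

# Hypothetical call; per the warning above, this requires the domain to be
# a CartesianDomain:
# problem.discretise_domain(sample_rules=sample_rules)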
@@ -13,8 +13,8 @@ class Trainer(lightning.pytorch.Trainer):

    PINA custom Trainer class to extend the standard Lightning functionality.

    This class enables specific features or behaviors required by the PINA
    framework. It modifies the standard
    :class:`lightning.pytorch.Trainer <lightning.pytorch.trainer.trainer.Trainer>`
    class to better support the training process in PINA.
    """
@@ -209,7 +209,7 @@ class Trainer(lightning.pytorch.Trainer):

    Manage the training process of the solver.

    :param dict kwargs: Additional keyword arguments. See `pytorch-lightning
        Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
        for details.
    """
    return super().fit(self.solver, datamodule=self.data_module, **kwargs)
@@ -219,7 +219,7 @@ class Trainer(lightning.pytorch.Trainer):

    Manage the test process of the solver.

    :param dict kwargs: Additional keyword arguments. See `pytorch-lightning
        Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
        for details.
    """
    return super().test(self.solver, datamodule=self.data_module, **kwargs)
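Both methods delegate to Lightning with the solver and datamodule stored on the trainer; a simplified sketch of that pattern follows (the constructor shown is illustrative, not PINA's actual signature).

import lightning.pytorch


class DelegatingTrainer(lightning.pytorch.Trainer):
    """Forward train/test calls to Lightning's fit()/test()."""

    def __init__(self, solver, data_module, **kwargs):
        super().__init__(**kwargs)
        self.solver = solver
        self.data_module = data_module

    def train(self, **kwargs):
        # Same forwarding as in the diff above.
        return super().fit(self.solver, datamodule=self.data_module, **kwargs)

    def test(self, **kwargs):
        return super().test(self.solver, datamodule=self.data_module, **kwargs)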