formatting
committed by FilippoOlivo
parent 480140dd31 · commit ea21754d53
@@ -20,18 +20,18 @@ class DummyDataloader:
    def __init__(self, dataset):
        """
        Prepare a dataloader object that returns the entire dataset in a single
        batch. Depending on the number of GPUs, the dataset is managed
        as follows:

        - **Distributed Environment** (multiple GPUs): Divides the dataset
          across processes using the rank and world size. Fetches only the
          portion of data corresponding to the current process.
        - **Non-Distributed Environment** (single GPU): Fetches the entire
          dataset.

        :param PinaDataset dataset: The dataset object to be processed.

        .. note::
            This dataloader is used when the batch size is ``None``.
        """
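For context on the single-batch behaviour described above, here is a minimal sketch of a loader with the same contract. It is illustrative only, not PINA's implementation, and assumes a standard map-style dataset exposing ``__len__`` and ``__getitem__``:

```python
import torch.distributed as dist


class SingleBatchLoader:
    """Illustrative single-batch loader: each process keeps only its shard."""

    def __init__(self, dataset):
        if dist.is_available() and dist.is_initialized():
            # Distributed environment: split indices by rank and world size.
            rank, world_size = dist.get_rank(), dist.get_world_size()
            indices = list(range(len(dataset)))[rank::world_size]
        else:
            # Single-GPU / CPU environment: keep every index.
            indices = list(range(len(dataset)))
        # Materialize the (local) portion of the dataset as one batch.
        self.batch = [dataset[i] for i in indices]

    def __iter__(self):
        # A single iteration yields the whole pre-fetched batch.
        yield self.batch

    def __len__(self):
        return 1
```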
@@ -78,8 +78,8 @@ class Collator:
        Initialize the object, setting the collate function based on whether
        automatic batching is enabled or not.

        :param dict max_conditions_lengths: ``dict`` containing the maximum
            number of data points to consider in a single batch for
            each condition.
        :param bool automatic_batching: Whether to enable automatic batching.
        :param PinaDataset dataset: The dataset where the data is stored.
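As an illustration of the dispatch described above (a sketch, not PINA's actual ``Collator``), the collate callable can be chosen once at construction time; the ``fetch_from_idx_list`` accessor used in the non-automatic branch is a hypothetical method name:

```python
import torch


class SimpleCollator:
    """Sketch: cap points per condition and dispatch on automatic batching."""

    def __init__(self, max_conditions_lengths, automatic_batching, dataset):
        self.max_conditions_lengths = max_conditions_lengths
        self.dataset = dataset
        # Choose the collate strategy once instead of branching per batch.
        self._collate = (
            self._collate_samples if automatic_batching else self._collate_indices
        )

    def _collate_samples(self, batch):
        # batch: list of per-sample dicts, one tensor entry per condition.
        out = {}
        for condition, max_len in self.max_conditions_lengths.items():
            points = [sample[condition] for sample in batch][:max_len]
            out[condition] = torch.stack(points)
        return out

    def _collate_indices(self, batch):
        # batch: list of integer indices; the dataset fetches them directly
        # (fetch_from_idx_list is a hypothetical accessor).
        return self.dataset.fetch_from_idx_list(batch)

    def __call__(self, batch):
        return self._collate(batch)
```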
@@ -276,7 +276,7 @@ class PinaGraphDataset(PinaDataset):
        :param data: List of items to collate in a single batch.
        :type data: list[Data] | list[Graph]
        :return: Batch object.
        :rtype: :class:`~torch_geometric.data.Batch` |
            :class:`~pina.graph.LabelBatch`
        """
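The batching described here mirrors what ``torch_geometric`` does when merging graphs; a minimal, self-contained example using the public PyG API (``LabelBatch`` from ``pina.graph`` would be used analogously):

```python
import torch
from torch_geometric.data import Batch, Data

graphs = [
    Data(x=torch.randn(4, 3), edge_index=torch.tensor([[0, 1], [1, 2]])),
    Data(x=torch.randn(5, 3), edge_index=torch.tensor([[0, 2], [2, 3]])),
]

# Collate the list into a single Batch (one large disconnected graph).
batch = Batch.from_data_list(graphs)
print(batch.num_graphs)  # 2
print(batch.x.shape)     # torch.Size([9, 3])
```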
@@ -389,7 +389,7 @@ class LabelTensor(torch.Tensor):
    def requires_grad_(self, mode=True):
        """
        Override the :meth:`~torch.Tensor.requires_grad_` method to handle
        the labels in the new tensor.
        For more details, see :meth:`~torch.Tensor.requires_grad_`.
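As a hedged sketch of why the override is needed (illustrative only, not ``LabelTensor``'s code): ``torch.Tensor.requires_grad_`` works in place and returns the tensor, so a subclass simply has to make sure its extra label metadata survives the call:

```python
import torch


class TaggedTensor(torch.Tensor):
    """Toy tensor subclass carrying a ``labels`` attribute."""

    def requires_grad_(self, mode=True):
        out = super().requires_grad_(mode)
        # requires_grad_ returns the tensor itself, so re-attaching the
        # metadata keeps the labels visible on the returned object.
        out.labels = getattr(self, "labels", None)
        return out


t = torch.randn(3).as_subclass(TaggedTensor)
t.labels = ["u", "v", "w"]
t.requires_grad_()
print(t.requires_grad, t.labels)  # True ['u', 'v', 'w']
```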
@@ -325,14 +325,14 @@ class FNO(KernelNeuralOperator):
        ``projection_net`` maps the hidden representation to the output
        function.

        :param x: The input tensor for performing the computation. Depending
            on the ``dimensions`` in the initialization, it expects a tensor
            with the following shapes:

            * 1D tensors: ``[batch, X, channels]``
            * 2D tensors: ``[batch, X, Y, channels]``
            * 3D tensors: ``[batch, X, Y, Z, channels]``

        :type x: torch.Tensor | LabelTensor
        :return: The output tensor.
        :rtype: torch.Tensor
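To make the expected layouts concrete, the shapes below match the list in the docstring (the tensors are random placeholders; no specific FNO configuration is implied):

```python
import torch

batch, X, Y, Z, channels = 8, 32, 32, 16, 1

x_1d = torch.randn(batch, X, channels)        # dimensions=1: [batch, X, channels]
x_2d = torch.randn(batch, X, Y, channels)     # dimensions=2: [batch, X, Y, channels]
x_3d = torch.randn(batch, X, Y, Z, channels)  # dimensions=3: [batch, X, Y, Z, channels]
```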
@@ -18,7 +18,7 @@ class TorchOptimizer(Optimizer):
        :param torch.optim.Optimizer optimizer_class: A
            :class:`torch.optim.Optimizer` class.
        :param dict kwargs: Additional parameters passed to ``optimizer_class``,
            see more
            `here <https://pytorch.org/docs/stable/optim.html#algorithms>`_.
        """
        check_consistency(optimizer_class, torch.optim.Optimizer, subclass=True)
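A usage sketch, assuming the wrapper is constructed from an optimizer class plus keyword arguments as the docstring suggests (the ``pina.optim`` import path is an assumption):

```python
import torch
from pina.optim import TorchOptimizer  # import path assumed

# Wrap Adam with its keyword arguments; they are forwarded to optimizer_class.
adam = TorchOptimizer(torch.optim.Adam, lr=1e-3, weight_decay=1e-5)
```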
@@ -24,7 +24,7 @@ class TorchScheduler(Scheduler):
        :param torch.optim.LRScheduler scheduler_class: A
            :class:`torch.optim.LRScheduler` class.
        :param dict kwargs: Additional parameters passed to ``scheduler_class``,
            see more
            `here <https://pytorch.org/docs/stable/optim.html#algorithms>`_.
        """
        check_consistency(scheduler_class, LRScheduler, subclass=True)
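The scheduler wrapper mirrors the optimizer one; same assumptions as the sketch above (import path and constructor form are assumed):

```python
import torch
from pina.optim import TorchScheduler  # import path assumed

# Keyword arguments are forwarded to scheduler_class when it is instantiated.
scheduler = TorchScheduler(torch.optim.lr_scheduler.ConstantLR, factor=0.5)
```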
@@ -178,14 +178,14 @@ class AbstractProblem(metaclass=ABCMeta):
            chebyshev sampling, ``chebyshev``; grid sampling, ``grid``.
        :param domains: The domains from which to sample. Default is ``all``.
        :type domains: str | list[str]
        :param dict sample_rules: A dictionary defining custom sampling rules
            for input variables. If provided, it must contain a dictionary
            specifying the sampling rule for each variable, overriding the
            ``n`` and ``mode`` arguments. Each key must correspond to the
            input variables from
            :meth:`~pina.problem.AbstractProblem.input_variables`, and its value
            should be another dictionary with
            two keys: ``n`` (number of points to sample) and ``mode``
            (sampling method). Defaults to ``None``.
        :raises RuntimeError: If both ``n`` and ``sample_rules`` are specified.
        :raises RuntimeError: If neither ``n`` nor ``sample_rules`` are set.
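An illustrative ``sample_rules`` dictionary in the format described above, assuming the sampling method is ``AbstractProblem.discretise_domain`` and that the problem's input variables are ``x`` and ``t`` (both names are placeholders):

```python
# One entry per input variable, each with its own number of points and mode.
sample_rules = {
    "x": {"n": 100, "mode": "grid"},
    "t": {"n": 50, "mode": "chebyshev"},
}

# `problem` is assumed to be an instance of a concrete AbstractProblem subclass.
problem.discretise_domain(sample_rules=sample_rules)
```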
@@ -214,8 +214,8 @@ class AbstractProblem(metaclass=ABCMeta):
            implemented for :class:`~pina.domain.cartesian.CartesianDomain`.

        .. warning::
            If a custom discretisation is applied by setting ``sample_rules``
            to a value other than ``None``, the discretised domain must be of
            class :class:`~pina.domain.cartesian.CartesianDomain`.
        """
@@ -13,8 +13,8 @@ class Trainer(lightning.pytorch.Trainer):
    PINA custom Trainer class to extend the standard Lightning functionality.

    This class enables specific features or behaviors required by the PINA
    framework. It modifies the standard
    :class:`lightning.pytorch.Trainer <lightning.pytorch.trainer.trainer.Trainer>`
    class to better support the training process in PINA.
    """
@@ -209,7 +209,7 @@ class Trainer(lightning.pytorch.Trainer):
        Manage the training process of the solver.

        :param dict kwargs: Additional keyword arguments. See `pytorch-lightning
            Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
            for details.
        """
        return super().fit(self.solver, datamodule=self.data_module, **kwargs)
@@ -219,7 +219,7 @@ class Trainer(lightning.pytorch.Trainer):
        Manage the test process of the solver.

        :param dict kwargs: Additional keyword arguments. See `pytorch-lightning
            Trainer API <https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api>`_
            for details.
        """
        return super().test(self.solver, datamodule=self.data_module, **kwargs)
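Putting the two entry points together, a hedged usage sketch (``solver`` is assumed to be an already-constructed PINA solver, and the extra keyword arguments are simply forwarded to Lightning):

```python
from pina.trainer import Trainer  # import path assumed

trainer = Trainer(solver=solver, max_epochs=100)
trainer.train()  # delegates to Lightning's fit() with PINA's solver and datamodule
trainer.test()   # delegates to Lightning's test() in the same way
```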