Fix doc data

This commit is contained in:
FilippoOlivo
2025-03-12 23:39:59 +01:00
committed by Nicola Demo
parent d411543b76
commit 7dd954ce50
2 changed files with 24 additions and 29 deletions

View File

@@ -176,7 +176,7 @@ class Collator:
def _collate_graph_dataset(self, data_list): def _collate_graph_dataset(self, data_list):
""" """
Function used to collate the data when the dataset is a Function used to collate data when the dataset is a
:class:`~pina.data.dataset.PinaGraphDataset`. :class:`~pina.data.dataset.PinaGraphDataset`.
:param data_list: Elememts to be collated. :param data_list: Elememts to be collated.
@@ -187,7 +187,6 @@ class Collator:
:raises RuntimeError: If the data is not a :raises RuntimeError: If the data is not a
:class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`. :class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`.
""" """
if isinstance(data_list[0], LabelTensor): if isinstance(data_list[0], LabelTensor):
return LabelTensor.cat(data_list) return LabelTensor.cat(data_list)
if isinstance(data_list[0], torch.Tensor): if isinstance(data_list[0], torch.Tensor):
@@ -201,14 +200,13 @@ class Collator:
def __call__(self, batch): def __call__(self, batch):
""" """
Perform the collation of the data points fetched from the dataset. Perform the collation of data fetched from the dataset. The behavoior
The behavoior of the function is set based on the batching strategy of the function is set based on the batching strategy during class
during class initialization. initialization.
:param batch: List of retrieved data or sampled indices. :param batch: List of retrieved data or sampled indices.
:type batch: list[int] | list[dict] :type batch: list[int] | list[dict]
:return: Dictionary containing the data points fetched from the dataset, :return: Dictionary containing colleted data fetched from the dataset.
collated.
:rtype: dict :rtype: dict
""" """
@@ -223,12 +221,10 @@ class PinaSampler:
def __new__(cls, dataset, shuffle): def __new__(cls, dataset, shuffle):
""" """
Instantiate the sampler based on the environment in which the code is Instantiate and initialize the sampler.
running.
:param PinaDataset dataset: The dataset to be sampled. :param PinaDataset dataset: The dataset from which to sample.
:param bool shuffle: whether to shuffle the dataset or not before :param bool shuffle: Whether to shuffle the dataset.
sampling.
:return: The sampler instance. :return: The sampler instance.
:rtype: torch.utils.data.Sampler :rtype: torch.utils.data.Sampler
""" """
@@ -267,18 +263,18 @@ class PinaDataModule(LightningDataModule):
pin_memory=False, pin_memory=False,
): ):
""" """
Initialize the object, creating datasets based on the input problem. Initialize the object and creating datasets based on the input problem.
:param AbstractProblem problem: The problem containing the data on which :param AbstractProblem problem: The problem containing the data on which
to create the datasets and dataloaders. to create the datasets and dataloaders.
:param float train_size: Fraction or number of elements in the training :param float train_size: Fraction of elements in the training split. It
split. It must be in the range [0, 1]. must be in the range [0, 1].
:param float test_size: Fraction or number of elements in the test :param float test_size: Fraction of elements in the test split. It must
split. It must be in the range [0, 1]. be in the range [0, 1].
:param float val_size: Fraction or number of elements in the validation :param float val_size: Fraction of elements in the validation split. It
split. It must be in the range [0, 1]. must be in the range [0, 1].
:param batch_size: The batch size used for training. If `None`, the :param batch_size: The batch size used for training. If `None`, the
entire dataset is used per batch. entire dataset is returned in a single batch.
:type batch_size: int | None :type batch_size: int | None
:param bool shuffle: Whether to shuffle the dataset before splitting. :param bool shuffle: Whether to shuffle the dataset before splitting.
Default True. Default True.
@@ -289,7 +285,7 @@ class PinaDataModule(LightningDataModule):
:param int num_workers: Number of worker threads for data loading. :param int num_workers: Number of worker threads for data loading.
Default 0 (serial loading). Default 0 (serial loading).
:param bool pin_memory: Whether to use pinned memory for faster data :param bool pin_memory: Whether to use pinned memory for faster data
transfer to GPU. (Default False). transfer to GPU. Default False.
:raises ValueError: If at least one of the splits is negative. :raises ValueError: If at least one of the splits is negative.
:raises ValueError: If the sum of the splits is different from 1. :raises ValueError: If the sum of the splits is different from 1.
@@ -370,7 +366,7 @@ class PinaDataModule(LightningDataModule):
If the stage is "fit", the training and validation datasets are created. If the stage is "fit", the training and validation datasets are created.
If the stage is "test", the testing dataset is created. If the stage is "test", the testing dataset is created.
:param str stage: The stage for which to perform the splitting. :param str stage: The stage for which to perform the dataset setup.
:raises ValueError: If the stage is neither "fit" nor "test". :raises ValueError: If the stage is neither "fit" nor "test".
""" """
@@ -534,10 +530,10 @@ class PinaDataModule(LightningDataModule):
def find_max_conditions_lengths(self, split): def find_max_conditions_lengths(self, split):
""" """
Define the maximum length of the conditions. Define the maximum length for each conditions.
:param dict split: The splits of the dataset. :param dict split: The split of the dataset.
:return: The maximum length of the conditions. :return: The maximum length per condition.
:rtype: dict :rtype: dict
""" """

View File

@@ -75,7 +75,7 @@ class PinaDatasetFactory:
class PinaDataset(Dataset, ABC): class PinaDataset(Dataset, ABC):
""" """
Abstract class for the PINA dataset. It defines the common interface for Abstract class for the PINA dataset. It defines the common interface for
the :class:`~pina.data.dataset.PinaTensorDataset` and :class:`~pina.data.dataset.PinaTensorDataset` and
:class:`~pina.data.dataset.PinaGraphDataset` classes. :class:`~pina.data.dataset.PinaGraphDataset` classes.
""" """
@@ -83,9 +83,8 @@ class PinaDataset(Dataset, ABC):
self, conditions_dict, max_conditions_lengths, automatic_batching self, conditions_dict, max_conditions_lengths, automatic_batching
): ):
""" """
Initialize a :class:`~pina.data.dataset.PinaDataset` instance by storing Initialize :class:`~pina.data.dataset.PinaDataset` instance by storing
the providedconditions dictionary, the maximum number of conditions to the provided conditions dictionary, and the automatic batching flag.
consider, and the automatic batching flag.
:param dict conditions_dict: Dictionary containing the conditions with :param dict conditions_dict: Dictionary containing the conditions with
data. data.