diff --git a/pina/data/data_module.py b/pina/data/data_module.py index e8bf702..f5f7d98 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -176,7 +176,7 @@ class Collator: def _collate_graph_dataset(self, data_list): """ - Function used to collate the data when the dataset is a + Function used to collate data when the dataset is a :class:`~pina.data.dataset.PinaGraphDataset`. :param data_list: Elememts to be collated. @@ -187,7 +187,6 @@ class Collator: :raises RuntimeError: If the data is not a :class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`. """ - if isinstance(data_list[0], LabelTensor): return LabelTensor.cat(data_list) if isinstance(data_list[0], torch.Tensor): @@ -201,14 +200,13 @@ class Collator: def __call__(self, batch): """ - Perform the collation of the data points fetched from the dataset. - The behavoior of the function is set based on the batching strategy - during class initialization. + Perform the collation of data fetched from the dataset. The behavior + of the function is set based on the batching strategy during class + initialization. :param batch: List of retrieved data or sampled indices. :type batch: list[int] | list[dict] - :return: Dictionary containing the data points fetched from the dataset, - collated. + :return: Dictionary containing collated data fetched from the dataset. :rtype: dict """ @@ -223,12 +221,10 @@ class PinaSampler: def __new__(cls, dataset, shuffle): """ - Instantiate the sampler based on the environment in which the code is - running. + Instantiate and initialize the sampler. - :param PinaDataset dataset: The dataset to be sampled. - :param bool shuffle: whether to shuffle the dataset or not before - sampling. + :param PinaDataset dataset: The dataset from which to sample. + :param bool shuffle: Whether to shuffle the dataset. :return: The sampler instance. 
:rtype: torch.utils.data.Sampler """ @@ -267,18 +263,18 @@ class PinaDataModule(LightningDataModule): pin_memory=False, ): """ - Initialize the object, creating datasets based on the input problem. + Initialize the object and create datasets based on the input problem. :param AbstractProblem problem: The problem containing the data on which to create the datasets and dataloaders. - :param float train_size: Fraction or number of elements in the training - split. It must be in the range [0, 1]. - :param float test_size: Fraction or number of elements in the test - split. It must be in the range [0, 1]. - :param float val_size: Fraction or number of elements in the validation - split. It must be in the range [0, 1]. + :param float train_size: Fraction of elements in the training split. It + must be in the range [0, 1]. + :param float test_size: Fraction of elements in the test split. It must + be in the range [0, 1]. + :param float val_size: Fraction of elements in the validation split. It + must be in the range [0, 1]. :param batch_size: The batch size used for training. If `None`, the - entire dataset is used per batch. + entire dataset is returned in a single batch. :type batch_size: int | None :param bool shuffle: Whether to shuffle the dataset before splitting. Default True. @@ -289,7 +285,7 @@ class PinaDataModule(LightningDataModule): :param int num_workers: Number of worker threads for data loading. Default 0 (serial loading). :param bool pin_memory: Whether to use pinned memory for faster data - transfer to GPU. (Default False). + transfer to GPU. Default False. :raises ValueError: If at least one of the splits is negative. :raises ValueError: If the sum of the splits is different from 1. @@ -370,7 +366,7 @@ class PinaDataModule(LightningDataModule): If the stage is "fit", the training and validation datasets are created. If the stage is "test", the testing dataset is created. - :param str stage: The stage for which to perform the splitting. 
+ :param str stage: The stage for which to perform the dataset setup. :raises ValueError: If the stage is neither "fit" nor "test". """ @@ -534,10 +530,10 @@ class PinaDataModule(LightningDataModule): def find_max_conditions_lengths(self, split): """ - Define the maximum length of the conditions. + Define the maximum length for each condition. - :param dict split: The splits of the dataset. - :return: The maximum length of the conditions. + :param dict split: The split of the dataset. + :return: The maximum length per condition. :rtype: dict """ diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 9b7ff87..0e24d22 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -75,7 +75,7 @@ class PinaDatasetFactory: class PinaDataset(Dataset, ABC): """ Abstract class for the PINA dataset. It defines the common interface for - the :class:`~pina.data.dataset.PinaTensorDataset` and + :class:`~pina.data.dataset.PinaTensorDataset` and :class:`~pina.data.dataset.PinaGraphDataset` classes. """ @@ -83,9 +83,8 @@ class PinaDataset(Dataset, ABC): self, conditions_dict, max_conditions_lengths, automatic_batching ): """ - Initialize a :class:`~pina.data.dataset.PinaDataset` instance by storing - the providedconditions dictionary, the maximum number of conditions to - consider, and the automatic batching flag. + Initialize a :class:`~pina.data.dataset.PinaDataset` instance by storing + the provided conditions dictionary and the automatic batching flag. :param dict conditions_dict: Dictionary containing the conditions with data.