diff --git a/pina/data/__init__.py b/pina/data/__init__.py index 4c14188..103c144 100644 --- a/pina/data/__init__.py +++ b/pina/data/__init__.py @@ -1,5 +1,5 @@ """ -Import data classes +Module for data data module and dataset. """ __all__ = ["PinaDataModule", "PinaDataset"] diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 2179984..8c9ea12 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -16,16 +16,16 @@ from ..collector import Collector class DummyDataloader: - """ " - Dummy dataloader used when batch size is None. It callects all the data - in self.dataset and returns it when it is called a single batch. + """ + Dataloader used when batch size is ``None``. It returns the entire dataset + in a single batch. """ def __init__(self, dataset): """ Preprare a dataloader object which will return the entire dataset - in a single batch. Depending on the number of GPUs, the dataset we - have the following cases: + in a single batch. Depending on the number of GPUs, the dataset is + managed as follows: - **Distributed Environment** (multiple GPUs): - Divides the dataset across processes using the rank and world @@ -38,7 +38,7 @@ class DummyDataloader: :param dataset: The dataset object to be processed. :type dataset: PinaDataset - .. note:: This data loader is used when the batch size is None. + .. note:: This data loader is used when the batch size is ``None``. """ if ( @@ -72,7 +72,9 @@ class DummyDataloader: class Collator: """ - Class used to collate the batch + This callable class is used to collate the data points fetched from the + dataset. The collation is performed based on the type of dataset used and + on the batching strategy. """ def __init__( @@ -121,7 +123,13 @@ class Collator: def _collate_torch_dataloader(self, batch): """ Function used to collate the batch + + :param list(dict) batch: List of retrieved data. + :return: Dictionary containing the data points fetched from the dataset, + collated. + :rtype: dict """ + batch_dict = {} if isinstance(batch, dict): return batch @@ -149,15 +157,15 @@ class Collator: def _collate_tensor_dataset(data_list): """ Function used to collate the data when the dataset is a - `PinaTensorDataset`. + :class:`PinaTensorDataset`. - :param data_list: List of `torch.Tensor` or `LabelTensor` to be - collated. + :param data_list: Elements to be collated. :type data_list: list(torch.Tensor) | list(LabelTensor) - :raises RuntimeError: If the data is not a `torch.Tensor` or a - `LabelTensor`. - :return: Batch of data + :return: Batch of data. :rtype: dict + + :raises RuntimeError: If the data is not a :class:`torch.Tensor` or a + :class:`LabelTensor`. """ if isinstance(data_list[0], LabelTensor): @@ -169,13 +177,15 @@ class Collator: def _collate_graph_dataset(self, data_list): """ Function used to collate the data when the dataset is a - `PinaGraphDataset`. + :class:`PinaGraphDataset`. - :param data_list: List of `Data` or `Graph` to be collated. - :type data_list: list(Data) | list(Graph) - :raises RuntimeError: If the data is not a `Data` or a `Graph`. - :return: Batch of data + :param data_list: Elememts to be collated. + :type data_list: list(torch_geometric.data.Data) | list(Graph) + :return: Batch of data. :rtype: dict + + :raises RuntimeError: If the data is not a + :class:`torch_geometric.data.Data` or a :class:`Graph`. """ if isinstance(data_list[0], LabelTensor): @@ -184,13 +194,18 @@ class Collator: return torch.cat(data_list) if isinstance(data_list[0], Data): return self.dataset.create_batch(data_list) - raise RuntimeError("Data must be Tensors or LabelTensor or pyG Data") + raise RuntimeError( + "Data must be Tensors or LabelTensor or pyG " + "torch_geometric.data.Data" + ) def __call__(self, batch): """ - Call the function to collate the batch, defined in __init__. + Perform the collation of the data points fetched from the dataset. + The behavoior of the function is set based on the batching strategy + during class initialization. - :param batch: list of indices or list of retrieved data + :param batch: List of retrieved data or sampled indices. :type batch: list(int) | list(dict) :return: Dictionary containing the data points fetched from the dataset, collated. @@ -202,13 +217,14 @@ class Collator: class PinaSampler: """ - Class used to create the sampler instance. + This class is used to create the sampler instance based on the shuffle + parameter and the environment in which the code is running. """ def __new__(cls, dataset, shuffle): """ - Create the sampler instance, according to shuffle and whether the - environment is distributed or not. + Instantiate the sampler based on the environment in which the code is + running. :param PinaDataset dataset: The dataset to be sampled. :param bool shuffle: whether to shuffle the dataset or not before @@ -232,8 +248,9 @@ class PinaSampler: class PinaDataModule(LightningDataModule): """ - This class extend LightningDataModule, allowing proper creation and - management of different types of Datasets defined in PINA + This class extends :class:`pytorch_lightning.LightningDataModule`, + allowing proper creation and management of different types of datasets + defined in PINA. """ def __init__( @@ -253,23 +270,31 @@ class PinaDataModule(LightningDataModule): Initialize the object, creating datasets based on the input problem. :param AbstractProblem problem: The problem containing the data on which - to train/test the model. + to create the datasets and dataloaders. :param float train_size: Fraction or number of elements in the training - split. + split. It must be in the range [0, 1]. :param float test_size: Fraction or number of elements in the test - split. + split. It must be in the range [0, 1]. :param float val_size: Fraction or number of elements in the validation - split. + split. It must be in the range [0, 1]. :param batch_size: The batch size used for training. If `None`, the entire dataset is used per batch. :type batch_size: int | None :param bool shuffle: Whether to shuffle the dataset before splitting. + Default True. :param bool repeat: Whether to repeat the dataset indefinitely. + Default False. :param automatic_batching: Whether to enable automatic batching. + Default False. :param int num_workers: Number of worker threads for data loading. - Default 0 (serial loading) + Default 0 (serial loading). For more information, see + https://pytorch.org/docs/stable/data.html#multi-process-data-loading :param bool pin_memory: Whether to use pinned memory for faster data - transfer to GPU. (Default False) + transfer to GPU. (Default False). For more information, see + https://pytorch.org/docs/stable/data.html#memory-pinning + + :raises ValueError: If at least one of the splits is negative. + :raises ValueError: If the sum of the splits is different from 1. """ super().__init__() @@ -278,6 +303,8 @@ class PinaDataModule(LightningDataModule): self.shuffle = shuffle self.repeat = repeat self.automatic_batching = automatic_batching + + # If batch size is None, num_workers has no effect if batch_size is None and num_workers != 0: warnings.warn( "Setting num_workers when batch_size is None has no effect on " @@ -286,6 +313,8 @@ class PinaDataModule(LightningDataModule): self.num_workers = 0 else: self.num_workers = num_workers + + # If batch size is None, pin_memory has no effect if batch_size is None and pin_memory: warnings.warn( "Setting pin_memory to True has no effect when " @@ -309,16 +338,22 @@ class PinaDataModule(LightningDataModule): splits_dict["train"] = train_size self.train_dataset = None else: + # Use the super method to create the train dataloader which + # raises NotImplementedError self.train_dataloader = super().train_dataloader if test_size > 0: splits_dict["test"] = test_size self.test_dataset = None else: + # Use the super method to create the train dataloader which + # raises NotImplementedError self.test_dataloader = super().test_dataloader if val_size > 0: splits_dict["val"] = val_size self.val_dataset = None else: + # Use the super method to create the train dataloader which + # raises NotImplementedError self.val_dataloader = super().val_dataloader self.collector_splits = self._create_splits(collector, splits_dict) @@ -326,7 +361,13 @@ class PinaDataModule(LightningDataModule): def setup(self, stage=None): """ - Perform the splitting of the dataset + Create the dataset objects for the given stage. + If the stage is "fit", the training and validation datasets are created. + If the stage is "test", the testing dataset is created. + + :param str stage: The stage for which to perform the splitting. + + :raises ValueError: If the stage is neither "fit" nor "test". """ if stage == "fit" or stage is None: self.train_dataset = PinaDatasetFactory( @@ -354,8 +395,18 @@ class PinaDataModule(LightningDataModule): raise ValueError("stage must be either 'fit' or 'test'.") @staticmethod - def _split_condition(condition_dict, splits_dict): - len_condition = len(condition_dict["input"]) + def _split_condition(single_condition_dict, splits_dict): + """ + Split the condition into different stages. + + :param dict single_condition_dict: The condition to be split. + :param dict splits_dict: The dictionary containing the number of + elements in each stage. + :return: A dictionary containing the split condition. + :rtype: dict + """ + + len_condition = len(single_condition_dict["input"]) lengths = [ int(len_condition * length) for length in splits_dict.values() @@ -374,7 +425,7 @@ class PinaDataModule(LightningDataModule): for stage, stage_len in splits_dict.items(): to_return_dict[stage] = { k: v[offset : offset + stage_len] - for k, v in condition_dict.items() + for k, v in single_condition_dict.items() if k != "equation" # Equations are NEVER dataloaded } @@ -386,7 +437,13 @@ class PinaDataModule(LightningDataModule): def _create_splits(self, collector, splits_dict): """ - Create the dataset objects putting data + Create the dataset objects putting data in the correct splits. + + :param Collector collector: The collector object containing the data. + :param dict splits_dict: The dictionary containing the number of + elements in each stage. + :return: The dictionary containing the dataset objects. + :rtype: dict """ # ----------- Auxiliary function ------------ @@ -422,6 +479,15 @@ class PinaDataModule(LightningDataModule): return dataset_dict def _create_dataloader(self, split, dataset): + """ " + Create the dataloader for the given split. + + :param str split: The split on which to create the dataloader. + :param str dataset: The dataset to be used for the dataloader. + :return: The dataloader for the given split. + :rtype: torch.utils.data.DataLoader + """ + shuffle = self.shuffle if split == "train" else False # Suppress the warning about num_workers. # In many cases, especially for PINNs, @@ -470,6 +536,7 @@ class PinaDataModule(LightningDataModule): :return: The maximum length of the conditions. :rtype: dict """ + max_conditions_lengths = {} for k, v in self.collector_splits[split].items(): if self.batch_size is None: @@ -484,7 +551,10 @@ class PinaDataModule(LightningDataModule): def val_dataloader(self): """ - Create the validation dataloader + Create the validation dataloader. + + :return: The validation dataloader + :rtype: DataLoader """ return self._create_dataloader("val", self.val_dataset) @@ -509,20 +579,17 @@ class PinaDataModule(LightningDataModule): @staticmethod def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): """ - Transfer the batch to the device. This method is called in the - training loop and is used to transfer the batch to the device. - This method is used when the batch size is None: batch has already - been transferred to the device. + Transfer the batch to the device. This method is used when the batch + size is None: batch has already been transferred to the device. :param list(tuple) batch: list of tuple where the first element of the tuple is the condition name and the second element is the data. - :param device: device to which the batch is transferred - :type device: torch.device - :param dataloader_idx: index of the dataloader - :type dataloader_idx: int + :param torch.device device: device to which the batch is transferred. + :param int dataloader_idx: index of the dataloader. :return: The batch transferred to the device. :rtype: list(tuple) """ + return batch def _transfer_batch_to_device(self, batch, device, dataloader_idx): @@ -531,12 +598,13 @@ class PinaDataModule(LightningDataModule): training loop and is used to transfer the batch to the device. :param dict batch: The batch to be transferred to the device. - :param device: The device to which the batch is transferred. - :type device: torch.device + :param torch.device device: The device to which the batch is + transferred. :param int dataloader_idx: The index of the dataloader. :return: The batch transferred to the device. :rtype: list(tuple) """ + batch = [ ( k, @@ -552,8 +620,18 @@ class PinaDataModule(LightningDataModule): @staticmethod def _check_slit_sizes(train_size, test_size, val_size): """ - Check if the splits are correct + Check if the splits are correct. The splits sizes must be positive and + the sum of the splits must be 1. + + :param float train_size: The size of the training split. + :param float test_size: The size of the testing split. + :param float val_size: The size of the validation split. + + :raises ValueError: If at least one of the splits is negative. + :raises ValueError: If the sum of the splits is different + from 1. """ + if train_size < 0 or test_size < 0 or val_size < 0: raise ValueError("The splits must be positive") if abs(train_size + test_size + val_size - 1) > 1e-6: @@ -567,6 +645,7 @@ class PinaDataModule(LightningDataModule): :return: The input points for training. :rtype dict """ + to_return = {} if hasattr(self, "train_dataset") and self.train_dataset is not None: to_return["train"] = self.train_dataset.input diff --git a/pina/data/dataset.py b/pina/data/dataset.py index e716621..ec405e1 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -1,8 +1,8 @@ """ -This module provide basic data management functionalities +Module for the PINA dataset """ -from abc import abstractmethod +from abc import abstractmethod, ABC from torch.utils.data import Dataset from torch_geometric.data import Data from ..graph import Graph, LabelBatch @@ -15,9 +15,10 @@ class PinaDatasetFactory: Depending on the type inside the conditions, it creates a different dataset object: - - :class:`PinaTensorDataset` for `torch.Tensor` - - :class:`PinaGraphDataset` for `list` of `torch_geometric.data.Data` - objects + - :class:`PinaTensorDataset` for handling :class:`torch.Tensor` and + :class:`LabelTensor` data. + - :class:`PinaGraphDataset` for handling :class:`Graph` and :class:`Data` + data. """ def __new__(cls, conditions_dict, **kwargs): @@ -28,7 +29,8 @@ class PinaDatasetFactory: :class:`PinaGraphDataset`, otherwise returns a :class:`PinaTensorDataset`. - :param dict conditions_dict: Dictionary containing the conditions. + :param dict conditions_dict: Dictionary containing all the conditions + to be included in the dataset instance. :return: A subclass of :class:`PinaDataset`. :rtype: :class:`PinaTensorDataset` | :class:`PinaGraphDataset` @@ -50,11 +52,11 @@ class PinaDatasetFactory: @staticmethod def _is_graph_dataset(conditions_dict): """ - Check if a graph is present in the conditions. + Check if a graph is present in the conditions (at least one time). :param conditions_dict: Dictionary containing the conditions. :type conditions_dict: dict - :return: True if a graph is present in the conditions, False otherwise + :return: True if a graph is present in the conditions, False otherwise. :rtype: bool """ @@ -68,25 +70,28 @@ class PinaDatasetFactory: return False -class PinaDataset(Dataset): +class PinaDataset(Dataset, ABC): """ - Abstract class for the PINA dataset + Abstract class for the PINA dataset. It defines the common interface for + the :class:`PinaTensorDataset` and :class:`PinaGraphDataset` classes. """ def __init__( self, conditions_dict, max_conditions_lengths, automatic_batching ): """ - Initialize the :class:`PinaDataset`. + Initialize a :class:`PinaDataset` instance by storing the provided + conditions dictionary, the maximum number of conditions to consider, + and the automatic batching flag. - Stores the conditions dictionary, the maximum number of conditions to - consider, and the automatic batching flag. - - :param dict conditions_dict: Dictionary containing the conditions. - :param dict max_conditions_lengths: Maximum number of data points to - consider in a single batch for each condition. - :param bool automatic_batching: Whether PyTorch automatic batching is - enabled in :class:`PinaDataModule`. + :param conditions_dict: Dictionary containing the conditions. + :type conditions_dict: dict + :param max_conditions_lengths: Specifies the maximum number of data + points to include in a single batch for each condition. + :type max_conditions_lengths: dict + :param automatic_batching: Indicates whether PyTorch automatic batching + is enabled in :class:`PinaDataModule`. + :type automatic_batching: bool """ # Store the conditions dictionary @@ -107,9 +112,9 @@ class PinaDataset(Dataset): def _get_max_len(self): """ - Returns the length of the longest condition in the dataset + Returns the length of the longest condition in the dataset. - :return: Length of the longest condition in the dataset + :return: Length of the longest condition in the dataset. :rtype: int """ @@ -129,9 +134,9 @@ class PinaDataset(Dataset): Return the index itself. This is used when automatic batching is disabled to postpone the data retrieval to the dataloader. - :param idx: Index + :param idx: Index. :type idx: int - :return: Index + :return: Index. :rtype: int """ @@ -143,8 +148,8 @@ class PinaDataset(Dataset): Return the data at the given index in the dataset. This is used when automatic batching is enabled. - :param int idx: Index - :return: A dictionary containing the data at the given index + :param int idx: Index. + :return: A dictionary containing the data at the given index. :rtype: dict """ @@ -156,23 +161,25 @@ class PinaDataset(Dataset): def get_all_data(self): """ - Return all data in the dataset + Return all data in the dataset. - :return: All data in the dataset + :return: A dictionary containing all the data in the dataset. :rtype: dict """ + index = list(range(len(self))) return self.fetch_from_idx_list(index) def fetch_from_idx_list(self, idx): """ - Return data from the dataset given a list of indices + Return data from the dataset given a list of indices. - :param idx: List of indices + :param idx: List of indices. :type idx: list - :return: Data from the dataset + :return: A dictionary containing the data at the given indices. :rtype: dict """ + to_return_dict = {} for condition, data in self.conditions_dict.items(): # Get the indices for the current condition @@ -190,30 +197,27 @@ class PinaDataset(Dataset): @abstractmethod def _retrive_data(self, data, idx_list): """ - Retrieve data from the dataset given a list of indices - - :param dict data: Dictionary containing the data - :param list idx_list: List of indices to retrieve - :return: Dictionary containing the data at the given indices - :rtype: dict + Abstract method to retrieve data from the dataset given a list of + indices. """ class PinaTensorDataset(PinaDataset): """ - Class for the PINA dataset with torch.Tensor data + Dataset class for the PINA dataset with :class:`torch.Tensor` and + :class:`LabelTensor` data. """ # Override _retrive_data method for torch.Tensor data def _retrive_data(self, data, idx_list): """ - Retrieve data from the dataset given a list of indices + Retrieve data from the dataset given a list of indices. :param data: Dictionary containing the data - (only torch.Tensor/LableTensor) + (only torch.Tensor/LableTensor). :type data: dict - :param list(int) idx_list: indices to retrieve - :return: Dictionary containing the data at the given indices + :param list(int) idx_list: indices to retrieve. + :return: Dictionary containing the data at the given indices. :rtype: dict """ @@ -222,9 +226,9 @@ class PinaTensorDataset(PinaDataset): @property def input(self): """ - Method to return all input points from the dataset. + Return the input data for the dataset. - :return: Dictionary containing the input points + :return: Dictionary containing the input points. :rtype: dict """ return {k: v["input"] for k, v in self.conditions_dict.items()} @@ -232,15 +236,17 @@ class PinaTensorDataset(PinaDataset): class PinaGraphDataset(PinaDataset): """ - Class for the PINA dataset with torch_geometric.data.Data data + Dataset class for the PINA dataset with :class:`torch_geometric.data.Data` + and :class:`Graph` data. """ def _create_graph_batch(self, data): """ - Create a LabelBatch object from a list of Data objects. + Create a LabelBatch object from a list of + :class:`torch_geometric.data.Data` objects. - :param data: List of Data or Graph objects - :type data: list(Data) | list(Graph) + :param data: List of items to collate in a single batch. + :type data: list(torch_geometric.data.Data) | list(Graph) :return: LabelBatch object all the graph collated in a single batch disconnected graphs. :rtype: LabelBatch @@ -255,7 +261,7 @@ class PinaGraphDataset(PinaDataset): :param data: torch.Tensor object of shape (N, ...) where N is the number of data points. :type data: torch.Tensor | LabelTensor - :return: reshaped torch.Tensor or LabelTensor object + :return: reshaped torch.Tensor or LabelTensor object. :rtype: torch.Tensor | LabelTensor """ out = data.reshape(-1, *data.shape[2:]) @@ -263,12 +269,13 @@ class PinaGraphDataset(PinaDataset): def create_batch(self, data): """ - Create a Batch object from a list of Data objects. + Create a Batch object from a list of :class:`torch_geometric.data.Data` + objects. - :param data: List of Data objects + :param data: List of items to collate in a single batch. :type data: list - :return: Batch object - :rtype: Batch or PinaBatch + :return: Batch object. + :rtype: Batch | PinaBatch """ if isinstance(data[0], Data): @@ -278,13 +285,14 @@ class PinaGraphDataset(PinaDataset): # Override _retrive_data method for graph handling def _retrive_data(self, data, idx_list): """ - Retrieve data from the dataset given a list of indices + Retrieve data from the dataset given a list of indices. - :param dict data: dictionary containing the data - :param list idx_list: list of indices to retrieve - :return: dictionary containing the data at the given indices + :param dict data: Dictionary containing the data. + :param list idx_list: List of indices to retrieve. + :return: Dictionary containing the data at the given indices. :rtype: dict """ + # Return the data from the current condition # If the data is a list of Data objects, create a Batch object # If the data is a list of torch.Tensor objects, create a torch.Tensor