diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index 7f467b7..5e1b006 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -7,52 +7,11 @@ different types of Datasets defined in PINA.
 import warnings
 from lightning.pytorch import LightningDataModule
 import torch
-from torch_geometric.data import Data
-from torch.utils.data import DataLoader, SequentialSampler, RandomSampler
-from torch.utils.data.distributed import DistributedSampler
 from ..label_tensor import LabelTensor
 from .dataset import PinaDatasetFactory
 from .dataloader import PinaDataLoader
 
 
-class PinaSampler:
-    """
-    This class is used to create the sampler instance based on the shuffle
-    parameter and the environment in which the code is running.
-    """
-
-    def __new__(cls, dataset):
-        """
-        Instantiate and initialize the sampler.
-
-        :param PinaDataset dataset: The dataset from which to sample.
-        :return: The sampler instance.
-        :rtype: :class:`torch.utils.data.Sampler`
-        """
-
-        if (
-            torch.distributed.is_available()
-            and torch.distributed.is_initialized()
-        ):
-            sampler = DistributedSampler(dataset)
-        else:
-            sampler = SequentialSampler(dataset)
-        return sampler
-
-
-def DataloaderCollector():
-
-    def __init__(self, dataloader_list):
-        """
-        Initialize the object.
-        """
-        assert isinstance(dataloader_list, list)
-        assert all(
-            isinstance(dataloader, DataLoader) for dataloader in dataloader_list
-        )
-        self.dataloader_list = dataloader_list
-
-
 class PinaDataModule(LightningDataModule):
     """
     This class extends :class:`~lightning.pytorch.core.LightningDataModule`,
@@ -68,7 +27,8 @@ class PinaDataModule(LightningDataModule):
         val_size=0.1,
         batch_size=None,
         shuffle=True,
-        repeat=False,
+        common_batch_size=True,
+        separate_conditions=False,
         automatic_batching=None,
         num_workers=0,
         pin_memory=False,
@@ -89,11 +49,12 @@ class PinaDataModule(LightningDataModule):
             Default is ``None``.
         :param bool shuffle: Whether to shuffle the dataset before splitting.
             Default ``True``.
-        :param bool repeat: If ``True``, in case of batch size larger than the
-            number of elements in a specific condition, the elements are
-            repeated until the batch size is reached. If ``False``, the number
-            of elements in the batch is the minimum between the batch size and
-            the number of elements in the condition. Default is ``False``.
+        :param bool common_batch_size: If ``True``, the same batch size is
+            used for all conditions. If ``False``, each condition can have
+            its own batch size, proportional to the size of the dataset in
+            that condition. Default is ``True``.
+        :param bool separate_conditions: If ``True``, dataloaders for each
+            condition are iterated separately. Default is ``False``.
         :param automatic_batching: If ``True``, automatic PyTorch batching is
             performed, which consists of extracting one element at a time from
             the dataset and collating them into a batch. This is useful
@@ -123,7 +84,8 @@ class PinaDataModule(LightningDataModule):
         # Store fixed attributes
         self.batch_size = batch_size
         self.shuffle = shuffle
-        self.repeat = repeat
+        self.common_batch_size = common_batch_size
+        self.separate_conditions = separate_conditions
         self.automatic_batching = automatic_batching
 
         # If batch size is None, num_workers has no effect
@@ -194,23 +156,16 @@ class PinaDataModule(LightningDataModule):
         if stage == "fit" or stage is None:
             self.train_dataset = PinaDatasetFactory(
                 self.data_splits["train"],
-                # max_conditions_lengths=self.find_max_conditions_lengths(
-                #     "train"
-                # ),
                 automatic_batching=self.automatic_batching,
             )
             if "val" in self.data_splits.keys():
                 self.val_dataset = PinaDatasetFactory(
                     self.data_splits["val"],
-                    # max_conditions_lengths=self.find_max_conditions_lengths(
-                    #     "val"
-                    # ),
                     automatic_batching=self.automatic_batching,
                 )
         elif stage == "test":
             self.test_dataset = PinaDatasetFactory(
                 self.data_splits["test"],
-                # max_conditions_lengths=self.find_max_conditions_lengths("test"),
                 automatic_batching=self.automatic_batching,
             )
         else:
@@ -326,30 +281,10 @@ class PinaDataModule(LightningDataModule):
             shuffle=self.shuffle,
             num_workers=self.num_workers,
             collate_fn=None,
-            common_batch_size=True,
+            common_batch_size=self.common_batch_size,
+            separate_conditions=self.separate_conditions,
         )
 
-    def find_max_conditions_lengths(self, split):
-        """
-        Define the maximum length for each conditions.
-
-        :param dict split: The split of the dataset.
-        :return: The maximum length per condition.
-        :rtype: dict
-        """
-
-        max_conditions_lengths = {}
-        for k, v in self.data_splits[split].items():
-            if self.batch_size is None:
-                max_conditions_lengths[k] = len(v["input"])
-            elif self.repeat:
-                max_conditions_lengths[k] = self.batch_size
-            else:
-                max_conditions_lengths[k] = min(
-                    len(v["input"]), self.batch_size
-                )
-        return max_conditions_lengths
-
     def val_dataloader(self):
         """
         Create the validation dataloader.
diff --git a/pina/data/dataloader.py b/pina/data/dataloader.py
index 6855171..97c8bc6 100644
--- a/pina/data/dataloader.py
+++ b/pina/data/dataloader.py
@@ -127,14 +127,14 @@ class PinaDataLoader:
         num_workers=0,
         collate_fn=None,
         common_batch_size=True,
+        separate_conditions=False,
     ):
         self.dataset_dict = dataset_dict
         self.batch_size = batch_size
         self.shuffle = shuffle
         self.num_workers = num_workers
         self.collate_fn = collate_fn
-
-        print(batch_size)
+        self.separate_conditions = separate_conditions
 
         if batch_size is None:
             batch_size_per_dataset = {
@@ -211,6 +211,8 @@ class PinaDataLoader:
         )
 
     def __len__(self):
+        if self.separate_conditions:
+            return sum(len(dl) for dl in self.dataloaders.values())
         return max(len(dl) for dl in self.dataloaders.values())
 
     def __iter__(self):
@@ -220,26 +222,24 @@
-        Itera per un numero di passi pari al dataloader più lungo (come da
-        __len__) e fa ricominciare i dataloader più corti quando si
-        esauriscono.
+        Iterate for a number of steps equal to the length of the longest
+        dataloader (see ``__len__``), restarting the shorter dataloaders
+        when they are exhausted. If ``separate_conditions`` is ``True``,
+        the dataloaders are instead iterated one after the other, yielding
+        batches from a single condition at a time.
         """
-        # 1. Crea un iteratore per ogni dataloader
+        if self.separate_conditions:
+            for split, dl in self.dataloaders.items():
+                for batch in dl:
+                    yield {split: batch}
+            return
+
         iterators = {split: iter(dl) for split, dl in self.dataloaders.items()}
-
-        # 2. Itera per il numero di batch del dataloader più lungo
         for _ in range(len(self)):
-
-            # 3. Prepara il dizionario di batch per questo step
             batch_dict = {}
-
-            # 4. Ottieni il prossimo batch da ogni iteratore
             for split, it in iterators.items():
                 try:
                     batch = next(it)
                 except StopIteration:
-                    # 5. Se un iteratore è esaurito, resettalo e prendi il primo batch
                     new_it = iter(self.dataloaders[split])
-                    iterators[split] = new_it  # Salva il nuovo iteratore
+                    iterators[split] = new_it
                     batch = next(new_it)
                 batch_dict[split] = batch
-
-            # 6. Restituisci il dizionario di batch
             yield batch_dict
diff --git a/pina/data/dataset.py b/pina/data/dataset.py
index d829770..88e86fe 100644
--- a/pina/data/dataset.py
+++ b/pina/data/dataset.py
@@ -1,41 +1,32 @@
 """Module for the PINA dataset classes."""
 
+import torch
 from torch.utils.data import Dataset
 from torch_geometric.data import Data
 from ..graph import Graph, LabelBatch
 from ..label_tensor import LabelTensor
-import torch
 
 
 class PinaDatasetFactory:
     """
-    Factory class for the PINA dataset.
-
-    Depending on the data type inside the conditions, it instanciate an object
-    belonging to the appropriate subclass of
-    :class:`~pina.data.dataset.PinaDataset`. The possible subclasses are:
-
-    - :class:`~pina.data.dataset.PinaTensorDataset`, for handling \
-        :class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data.
-    - :class:`~pina.data.dataset.PinaGraphDataset`, for handling \
-        :class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data.
+    Factory class for the PINA dataset. It instantiates a
+    :class:`~pina.data.dataset.PinaDataset` for each condition and collects
+    the instances in a dictionary indexed by condition name.
     """
 
     def __new__(cls, conditions_dict, **kwargs):
         """
-        Instantiate the appropriate subclass of
-        :class:`~pina.data.dataset.PinaDataset`.
-
-        If a graph is present in the conditions, returns a
-        :class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a
-        :class:`~pina.data.dataset.PinaTensorDataset`.
-
-        :param dict conditions_dict: Dictionary containing all the conditions
-            to be included in the dataset instance.
-        :return: A subclass of :class:`~pina.data.dataset.PinaDataset`.
-        :rtype: PinaTensorDataset | PinaGraphDataset
-
-        :raises ValueError: If an empty dictionary is provided.
+        Instantiate a :class:`~pina.data.dataset.PinaDataset` for each
+        condition in ``conditions_dict``.
+
+        :param dict conditions_dict: Dictionary containing all the conditions
+            to be included in the dataset instance.
+        :return: A dictionary mapping each condition name to its
+            :class:`~pina.data.dataset.PinaDataset`.
+        :rtype: dict
+
+        :raises ValueError: If an empty dictionary is provided, or if the
+            data of a condition is not a dictionary.
         """
 
         # Check if conditions_dict is empty
@@ -50,28 +41,11 @@
                 raise ValueError(
                     f"Condition '{name}' data must be a dictionary"
                 )
-
-        # is_graph = cls._is_graph_dataset(conditions_dict)
-        # if is_graph:
-        #     raise NotImplementedError("PinaGraphDataset is not implemented yet.")
-
-            dataset_dict[name] = PinaTensorDataset(data, **kwargs)
+            dataset_dict[name] = PinaDataset(data, **kwargs)
 
         return dataset_dict
 
-    @staticmethod
-    def _is_graph_dataset(cond_data):
-        """
-        TODO: Docstring
-        """
-        # Iterate over the values of the current condition
-        for cond in cond_data.values():
-            if isinstance(cond, (Data, Graph, list, tuple)):
-                return True
-        return False
-
-
-class PinaTensorDataset(Dataset):
+class PinaDataset(Dataset):
     """
     Dataset class for the PINA dataset with :class:`torch.Tensor` and
     :class:`~pina.label_tensor.LabelTensor` data.
@@ -91,9 +65,8 @@
         self.automatic_batching = (
             automatic_batching if automatic_batching is not None else True
         )
-        self.stack_fn = (
-            {}
-        )  # LabelTensor.stack if any(isinstance(v, LabelTensor) for v in data_dict.values()) else torch.stack
+        self.stack_fn = {}
+        # Determine stacking functions for each data type (used in collate_fn)
         for k, v in data_dict.items():
            if isinstance(v, LabelTensor):
                self.stack_fn[k] = LabelTensor.stack
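
Note on the two iteration modes: the sketch below mirrors the logic of `PinaDataLoader.__len__`/`__iter__` from this patch, using plain `torch.utils.data` objects in place of the PINA dataset classes. The condition names, tensor shapes, and batch sizes are made up for illustration and are not part of the PINA API.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-ins for the per-condition dataloaders held in
# PinaDataLoader.dataloaders (names and sizes are illustrative).
loaders = {
    "physics": DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4),
    "data": DataLoader(TensorDataset(torch.randn(2, 2)), batch_size=2),
}

# separate_conditions=False (default): the length is that of the longest
# loader (2 steps here); shorter loaders restart when exhausted, so every
# step yields a dict with one batch per condition.
iterators = {name: iter(dl) for name, dl in loaders.items()}
for _ in range(max(len(dl) for dl in loaders.values())):
    batch_dict = {}
    for name, it in iterators.items():
        try:
            batch = next(it)
        except StopIteration:
            # Restart the exhausted loader, as in __iter__ above.
            iterators[name] = it = iter(loaders[name])
            batch = next(it)
        batch_dict[name] = batch
    print(sorted(batch_dict))  # ['data', 'physics'] at every step

# separate_conditions=True: the length is the sum of the loader lengths
# (3 steps here), and each step yields a dict with a single condition.
for name, dl in loaders.items():
    for batch in dl:
        print({name: batch[0].shape})
```

With `separate_conditions=False`, every step sees all conditions, with the shorter ones recycled; with `separate_conditions=True`, each step touches one condition at a time, which is why `__len__` returns the sum rather than the maximum of the per-condition lengths.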