From 18b02f43c5728af8ca7fd1dc072bdea8b2fc35a4 Mon Sep 17 00:00:00 2001
From: FilippoOlivo
Date: Thu, 13 Nov 2025 14:01:18 +0100
Subject: [PATCH] fix some codacy warnings

---
 pina/data/data_module.py |   9 ++-
 pina/data/dataloader.py  | 127 ++++++++++++++++++++++++++++-----------
 pina/data/dataset.py     |  22 +++++--
 3 files changed, 112 insertions(+), 46 deletions(-)

diff --git a/pina/data/data_module.py b/pina/data/data_module.py
index 5e1b006..f1910f8 100644
--- a/pina/data/data_module.py
+++ b/pina/data/data_module.py
@@ -255,7 +255,7 @@ class PinaDataModule(LightningDataModule):
             dataset_dict[key].update({condition_name: data})
         return dataset_dict
 
-    def _create_dataloader(self, split, dataset):
+    def _create_dataloader(self, dataset):
         """ "
         Create the dataloader for the given split.
 
@@ -280,7 +280,6 @@ class PinaDataModule(LightningDataModule):
             batch_size=self.batch_size,
             shuffle=self.shuffle,
             num_workers=self.num_workers,
-            collate_fn=None,
             common_batch_size=self.common_batch_size,
             separate_conditions=self.separate_conditions,
         )
@@ -292,7 +291,7 @@ class PinaDataModule(LightningDataModule):
         :return: The validation dataloader
         :rtype: torch.utils.data.DataLoader
         """
-        return self._create_dataloader("val", self.val_dataset)
+        return self._create_dataloader(self.val_dataset)
 
     def train_dataloader(self):
         """
@@ -301,7 +300,7 @@ class PinaDataModule(LightningDataModule):
         :return: The training dataloader
         :rtype: torch.utils.data.DataLoader
         """
-        return self._create_dataloader("train", self.train_dataset)
+        return self._create_dataloader(self.train_dataset)
 
     def test_dataloader(self):
         """
@@ -310,7 +309,7 @@ class PinaDataModule(LightningDataModule):
         :return: The testing dataloader
         :rtype: torch.utils.data.DataLoader
         """
-        return self._create_dataloader("test", self.test_dataset)
+        return self._create_dataloader(self.test_dataset)
 
     @staticmethod
     def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
diff --git a/pina/data/dataloader.py b/pina/data/dataloader.py
index 97c8bc6..29b3673 100644
--- a/pina/data/dataloader.py
+++ b/pina/data/dataloader.py
@@ -1,11 +1,17 @@
-from torch.utils.data import DataLoader
+"""DataLoader module for PinaDataset."""
+
+import itertools
 from functools import partial
+import torch
+from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import SequentialSampler
-import torch
 
 
 class DummyDataloader:
+    """
+    DataLoader that returns the entire dataset in a single batch.
+    """
 
     def __init__(self, dataset):
         """
@@ -24,18 +30,18 @@
         .. note::
             This dataloader is used when the batch size is ``None``.
         """
-        print("Using DummyDataloader")
-        if (
-            torch.distributed.is_available()
-            and torch.distributed.is_initialized()
-        ):
+        # Handle distributed environment
+        if PinaSampler.is_distributed():
+            # Get rank and world size
             rank = torch.distributed.get_rank()
             world_size = torch.distributed.get_world_size()
+            # Ensure dataset is large enough
             if len(dataset) < world_size:
                 raise RuntimeError(
                     "Dimension of the dataset smaller than world size."
                     " Increase the size of the partition or use a single GPU"
                 )
+            # Split dataset among processes
             idx, i = [], rank
             while i < len(dataset):
                 idx.append(i)
@@ -43,15 +49,28 @@
         else:
             idx = list(range(len(dataset)))
 
-        self.dataset = dataset._getitem_from_list(idx)
+        self.dataset = dataset.getitem_from_list(idx)
 
     def __iter__(self):
+        """
+        Iterate over the dataloader.
+        """
         return self
 
     def __len__(self):
+        """
+        Return the length of the dataloader, which is always 1.
+        :return: The length of the dataloader.
+        :rtype: int
+        """
         return 1
 
     def __next__(self):
+        """
+        Return the entire dataset as a single batch.
+        :return: The entire dataset.
+        :rtype: dict
+        """
         return self.dataset
 
 
@@ -70,10 +89,7 @@ class PinaSampler:
         :rtype: :class:`torch.utils.data.Sampler`
         """
 
-        if (
-            torch.distributed.is_available()
-            and torch.distributed.is_initialized()
-        ):
+        if cls.is_distributed():
             sampler = DistributedSampler(dataset, shuffle=shuffle)
         else:
             if shuffle:
@@ -82,6 +98,18 @@ class PinaSampler:
             sampler = SequentialSampler(dataset)
         return sampler
 
+    @staticmethod
+    def is_distributed():
+        """
+        Check if the sampler is distributed.
+        :return: True if the sampler is distributed, False otherwise.
+        :rtype: bool
+        """
+        return (
+            torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+        )
+
 
 def _collect_items(batch):
     """
@@ -97,11 +125,12 @@ def _collect_items(batch):
 
 def collate_fn_custom(batch, dataset):
     """
-    Override the default collate function to handle datasets without automatic batching.
+    Override the default collate function to handle datasets without automatic
+    batching.
     :param batch: List of indices from the dataset.
    :param dataset: The PinaDataset instance (must be provided).
     """
-    return dataset._getitem_from_list(batch)
+    return dataset.getitem_from_list(batch)
 
 
 def collate_fn_default(batch, stack_fn):
@@ -109,7 +138,6 @@ def collate_fn_default(batch, stack_fn):
     Default collate function that simply returns the batch as is.
     :param batch: List of data samples.
     """
-    print("Using default collate function")
     to_return = _collect_items(batch)
     return {k: stack_fn[k](v) for k, v in to_return.items()}
 
 
@@ -123,30 +151,36 @@ class PinaDataLoader:
         self,
         dataset_dict,
         batch_size,
-        shuffle=False,
         num_workers=0,
-        collate_fn=None,
+        shuffle=False,
         common_batch_size=True,
         separate_conditions=False,
     ):
         self.dataset_dict = dataset_dict
         self.batch_size = batch_size
-        self.shuffle = shuffle
         self.num_workers = num_workers
-        self.collate_fn = collate_fn
+        self.shuffle = shuffle
         self.separate_conditions = separate_conditions
 
+        # Batch size None means we want to load the entire dataset in a single
+        # batch
         if batch_size is None:
             batch_size_per_dataset = {
                 split: None for split in dataset_dict.keys()
             }
         else:
-            if common_batch_size:
+            # Compute batch size per dataset
+            if common_batch_size:  # all datasets have the same batch size
+                # (the sum of the batch sizes is equal to
+                # n_conditions * batch_size)
                 batch_size_per_dataset = {
                     split: batch_size for split in dataset_dict.keys()
                 }
-            else:
+            else:  # batch size proportional to dataset size (the sum of the
+                # batch sizes is equal to the specified batch size)
                 batch_size_per_dataset = self._compute_batch_size()
+
+        # Create a dataloader per dataset
         self.dataloaders = {
             split: self._create_dataloader(
                 dataset, batch_size_per_dataset[split]
@@ -158,21 +192,26 @@ class PinaDataLoader:
         """
         Compute an appropriate batch size for the given dataset.
         """
+        # Compute number of elements per dataset
         elements_per_dataset = {
            dataset_name: len(dataset)
             for dataset_name, dataset in self.dataset_dict.items()
         }
+        # Compute the total number of elements
         total_elements = sum(el for el in elements_per_dataset.values())
+        # Compute the portion of each dataset
         portion_per_dataset = {
             name: el / total_elements
             for name, el in elements_per_dataset.items()
         }
+        # Compute batch size per dataset. Ensure at least 1 element per
+        # dataset.
         batch_size_per_dataset = {
             name: max(1, int(portion * self.batch_size))
             for name, portion in portion_per_dataset.items()
         }
+        # Adjust batch sizes to match the specified total batch size
         tot_el_per_batch = sum(el for el in batch_size_per_dataset.values())
-
         if self.batch_size > tot_el_per_batch:
             difference = self.batch_size - tot_el_per_batch
             while difference > 0:
@@ -194,33 +233,45 @@ class PinaDataLoader:
         return batch_size_per_dataset
 
     def _create_dataloader(self, dataset, batch_size):
-        print(batch_size)
-        if batch_size is None:
+        """
+        Create the dataloader for the given dataset.
+        """
+        # If batch size is None, use DummyDataloader
+        if batch_size is None or batch_size >= len(dataset):
             return DummyDataloader(dataset)
+        # Determine the appropriate collate function
         if not dataset.automatic_batching:
             collate_fn = partial(collate_fn_custom, dataset=dataset)
         else:
             collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn)
+
+        # Create and return the dataloader
         return DataLoader(
             dataset,
             batch_size=batch_size,
-            num_workers=self.num_workers,
             collate_fn=collate_fn,
+            num_workers=self.num_workers,
             sampler=PinaSampler(dataset, shuffle=self.shuffle),
         )
 
     def __len__(self):
+        """
+        Return the length of the dataloader.
+        :return: The length of the dataloader.
+        :rtype: int
+        """
+        # If separate conditions, return sum of lengths of all dataloaders
+        # else, return max length among dataloaders
         if self.separate_conditions:
             return sum(len(dl) for dl in self.dataloaders.values())
         return max(len(dl) for dl in self.dataloaders.values())
 
     def __iter__(self):
         """
-        Restituisce un iteratore che produce dizionari di batch.
-
-        Itera per un numero di passi pari al dataloader più lungo (come da __len__)
-        e fa ricominciare i dataloader più corti quando si esauriscono.
+        Iterate over the dataloader.
+        :return: Yields batches from the dataloader.
+        :rtype: dict
         """
         if self.separate_conditions:
             for split, dl in self.dataloaders.items():
@@ -228,15 +279,19 @@
                 yield {split: batch}
             return
 
-        iterators = {split: iter(dl) for split, dl in self.dataloaders.items()}
+        iterators = {
+            split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
+        }
+
         for _ in range(len(self)):
             batch_dict = {}
             for split, it in iterators.items():
-                try:
-                    batch = next(it)
-                except StopIteration:
-                    new_it = iter(self.dataloaders[split])
-                    iterators[split] = new_it
-                    batch = next(new_it)
+
+                # Iterate through each dataloader and get the next batch
+                batch = next(it, None)
+                # Check if batch is None (in case of uneven lengths)
+                if batch is None:
+                    return
+
                 batch_dict[split] = batch
             yield batch_dict
diff --git a/pina/data/dataset.py b/pina/data/dataset.py
index 674b512..bcb44aa 100644
--- a/pina/data/dataset.py
+++ b/pina/data/dataset.py
@@ -9,26 +9,38 @@
 from ..label_tensor import LabelTensor
 
 
 class PinaDatasetFactory:
     """
-    TODO: Update docstring
+    Factory class to create PINA datasets based on the provided conditions
+    dictionary.
+    :param dict conditions_dict: A dictionary where keys are condition names
+        and values are dictionaries containing the associated data.
+    :return: A dictionary mapping condition names to their respective
+        :class:`PinaDataset` instances.
     """
 
     def __new__(cls, conditions_dict, **kwargs):
         """
-        TODO: Update docstring
+        Create PINA dataset instances based on the provided conditions
+        dictionary.
+        :param dict conditions_dict: A dictionary where keys are condition names
+            and values are dictionaries containing the associated data.
+        :return: A dictionary mapping condition names to their respective
+            :class:`PinaDataset` instances.
         """
         # Check if conditions_dict is empty
         if len(conditions_dict) == 0:
             raise ValueError("No conditions provided")
-        dataset_dict = {}
+        dataset_dict = {}  # Dictionary to hold the created datasets
         # Check is a Graph is present in the conditions
         for name, data in conditions_dict.items():
+            # Validate that data is a dictionary
             if not isinstance(data, dict):
                 raise ValueError(
                     f"Condition '{name}' data must be a dictionary"
                 )
+            # Create PinaDataset instance for each condition
             dataset_dict[name] = PinaDataset(data, **kwargs)
         return dataset_dict
 
@@ -90,7 +102,7 @@ class PinaDataset(Dataset):
             }
             return idx
 
-    def _getitem_from_list(self, idx_list):
+    def getitem_from_list(self, idx_list):
         """
         Return data from the dataset given a list of indices.
 
@@ -101,7 +113,7 @@ class PinaDataset(Dataset):
 
         to_return = {}
         for field_name, data in self.data.items():
-            if self.stack_fn[field_name] == LabelBatch.from_data_list:
+            if self.stack_fn[field_name] is LabelBatch.from_data_list:
                 to_return[field_name] = self.stack_fn[field_name](
                     [data[i] for i in idx_list]
                 )
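
Below are a few self-contained sketches of the behaviours this patch touches. They are illustrative only, are not part of the patch, and every helper name in them is invented for the example.

The reworked DummyDataloader.__init__ splits the dataset round-robin across ranks: rank r keeps indices r, r + world_size, r + 2 * world_size, and so on, and refuses to run when there are fewer samples than ranks. A minimal sketch of that index assignment outside any torch.distributed context (split_indices_round_robin is an invented name):

def split_indices_round_robin(n_samples, rank, world_size):
    """Assign every world_size-th index to the given rank, starting at rank."""
    if n_samples < world_size:
        raise RuntimeError(
            "Dimension of the dataset smaller than world size."
            " Increase the size of the partition or use a single GPU"
        )
    idx, i = [], rank
    while i < n_samples:
        idx.append(i)
        i += world_size
    return idx


# With 10 samples on 3 ranks, rank 1 receives indices [1, 4, 7].
print(split_indices_round_robin(10, rank=1, world_size=3))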
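
_create_dataloader now binds the extra argument of the collate function (the dataset itself, or its stack_fn) with functools.partial before handing it to torch.utils.data.DataLoader. A small sketch of the same pattern with a toy dataset and collate function (RangeDataset and toy_collate are invented names):

from functools import partial

import torch
from torch.utils.data import DataLoader, Dataset


class RangeDataset(Dataset):
    """Toy map-style dataset returning single integers as 0-d tensors."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        return torch.tensor(idx)


def toy_collate(batch, scale):
    """Collate function whose extra argument is bound via functools.partial."""
    return torch.stack(batch) * scale


loader = DataLoader(
    RangeDataset(6),
    batch_size=3,
    collate_fn=partial(toy_collate, scale=10),
)
for batch in loader:
    print(batch)  # tensor([ 0, 10, 20]), then tensor([30, 40, 50])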
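
_compute_batch_size distributes the requested batch size across conditions in proportion to each dataset's length, guarantees at least one sample per condition, and then corrects the rounding drift so the per-condition sizes add up to the requested total. A standalone sketch of that allocation logic (proportional_batch_sizes is an invented name, and the drift correction is a simplified stand-in for the loop in the patch):

def proportional_batch_sizes(sizes, batch_size):
    """Split batch_size across datasets proportionally to their lengths."""
    total = sum(sizes.values())
    # Initial allocation: proportional share, but at least 1 per dataset.
    alloc = {
        name: max(1, int(batch_size * n / total)) for name, n in sizes.items()
    }
    names = list(alloc)
    i = 0
    # Add one sample at a time while the total falls short of the target.
    while sum(alloc.values()) < batch_size:
        alloc[names[i % len(names)]] += 1
        i += 1
    # Remove samples while the total overshoots, never going below 1.
    while sum(alloc.values()) > batch_size and any(v > 1 for v in alloc.values()):
        name = names[i % len(names)]
        if alloc[name] > 1:
            alloc[name] -= 1
        i += 1
    return alloc


# Three conditions with 1000/100/10 samples and a global batch size of 32.
print(proportional_batch_sizes({"pde": 1000, "bc": 100, "data": 10}, 32))
# {'pde': 29, 'bc': 2, 'data': 1}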
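
The new PinaDataLoader.__iter__ wraps every per-condition dataloader in itertools.cycle and yields one combined dictionary per step, for as many steps as the longest loader reports, so shorter conditions simply restart instead of raising StopIteration. The same pattern with plain lists standing in for the per-condition dataloaders:

import itertools

loaders = {
    "pde": [10, 11, 12, 13],  # 4 batches
    "bc": ["a", "b"],         # 2 batches, cycled to keep up
}

iterators = {name: itertools.cycle(dl) for name, dl in loaders.items()}
n_steps = max(len(dl) for dl in loaders.values())

for _ in range(n_steps):
    print({name: next(it) for name, it in iterators.items()})
# {'pde': 10, 'bc': 'a'}
# {'pde': 11, 'bc': 'b'}
# {'pde': 12, 'bc': 'a'}
# {'pde': 13, 'bc': 'b'}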
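
PinaDatasetFactory.__new__ validates that every condition maps to a dictionary and returns a dict of per-condition dataset objects keyed by condition name, rather than a single dataset. A toy sketch of that factory-returning-a-dict shape (ToyDataset and ToyDatasetFactory are invented stand-ins, not the PINA classes):

class ToyDataset:
    """Stand-in for PinaDataset: stores one condition's data dictionary."""

    def __init__(self, data, **kwargs):
        self.data = data
        self.kwargs = kwargs


class ToyDatasetFactory:
    def __new__(cls, conditions_dict, **kwargs):
        if len(conditions_dict) == 0:
            raise ValueError("No conditions provided")
        dataset_dict = {}
        for name, data in conditions_dict.items():
            if not isinstance(data, dict):
                raise ValueError(f"Condition '{name}' data must be a dictionary")
            dataset_dict[name] = ToyDataset(data, **kwargs)
        return dataset_dict


datasets = ToyDatasetFactory({"bc": {"input": [0, 1]}, "pde": {"input": [2, 3]}})
print(list(datasets))  # ['bc', 'pde']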