Fix some Codacy warnings

FilippoOlivo
2025-11-13 14:01:18 +01:00
parent 6bb44052b0
commit 18b02f43c5
3 changed files with 112 additions and 46 deletions


@@ -1,11 +1,17 @@
from torch.utils.data import DataLoader
"""DataLoader module for PinaDataset."""
import itertools
from functools import partial
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler
import torch
class DummyDataloader:
"""
DataLoader that returns the entire dataset in a single batch.
"""
def __init__(self, dataset):
"""
@@ -24,18 +30,18 @@ class DummyDataloader:
.. note::
This dataloader is used when the batch size is ``None``.
"""
print("Using DummyDataloader")
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
# Handle distributed environment
if PinaSampler.is_distributed():
# Get rank and world size
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
# Ensure dataset is large enough
if len(dataset) < world_size:
raise RuntimeError(
"Dimension of the dataset smaller than world size."
" Increase the size of the partition or use a single GPU"
)
# Split dataset among processes
idx, i = [], rank
while i < len(dataset):
idx.append(i)
@@ -43,15 +49,28 @@ class DummyDataloader:
else:
idx = list(range(len(dataset)))
self.dataset = dataset._getitem_from_list(idx)
self.dataset = dataset.getitem_from_list(idx)
def __iter__(self):
"""
Iterate over the dataloader.
"""
return self
def __len__(self):
"""
Return the length of the dataloader, which is always 1.
:return: The length of the dataloader.
:rtype: int
"""
return 1
def __next__(self):
"""
Return the entire dataset as a single batch.
:return: The entire dataset.
:rtype: dict
"""
return self.dataset
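A minimal usage sketch of the class above, run in a single, non-distributed process against this module; ``ToyDataset`` and its data are illustrative stand-ins for the real PinaDataset API (only ``__len__`` and ``getitem_from_list`` are taken from the code).

import torch

class ToyDataset:
    # Illustrative stand-in exposing the two methods DummyDataloader needs
    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32)

    def __len__(self):
        return len(self.x)

    def getitem_from_list(self, idx):
        # Return the selected points as one dictionary batch
        return {"x": self.x[idx]}

loader = DummyDataloader(ToyDataset(10))
print(len(loader))              # 1: the whole dataset is a single batch
print(next(iter(loader))["x"])  # tensor([0., 1., ..., 9.])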
@@ -70,10 +89,7 @@ class PinaSampler:
:rtype: :class:`torch.utils.data.Sampler`
"""
if (
torch.distributed.is_available()
and torch.distributed.is_initialized()
):
if cls.is_distributed():
sampler = DistributedSampler(dataset, shuffle=shuffle)
else:
if shuffle:
@@ -82,6 +98,18 @@ class PinaSampler:
sampler = SequentialSampler(dataset)
return sampler
@staticmethod
def is_distributed():
"""
Check whether a distributed environment is available and initialized.
:return: ``True`` if running in a distributed environment, ``False`` otherwise.
:rtype: bool
"""
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
)
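A small sketch of the sampler factory in a single, non-distributed process; the call mirrors ``PinaSampler(dataset, shuffle=...)`` as used in ``_create_dataloader`` below, and the toy list stands in for a dataset.

from torch.utils.data.sampler import SequentialSampler

points = list(range(6))  # any sized, indexable container works as a stand-in
sampler = PinaSampler(points, shuffle=False)
print(isinstance(sampler, SequentialSampler))  # True when not distributed and not shuffling
print(list(sampler))                           # [0, 1, 2, 3, 4, 5]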
def _collect_items(batch):
"""
@@ -97,11 +125,12 @@ def _collect_items(batch):
def collate_fn_custom(batch, dataset):
"""
Override the default collate function to handle datasets without automatic batching.
Override the default collate function to handle datasets without automatic
batching.
:param batch: List of indices from the dataset.
:param dataset: The PinaDataset instance (must be provided).
"""
return dataset._getitem_from_list(batch)
return dataset.getitem_from_list(batch)
def collate_fn_default(batch, stack_fn):
@@ -109,7 +138,6 @@ def collate_fn_default(batch, stack_fn):
Default collate function that simply returns the batch as is.
:param batch: List of data samples.
"""
print("Using default collate function")
to_return = _collect_items(batch)
return {k: stack_fn[k](v) for k, v in to_return.items()}
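To make the default collate path concrete, here is a hedged sketch of the per-key stacking it performs, assuming ``_collect_items`` groups a list of per-sample dictionaries into a dictionary of lists (its body lies outside this hunk); the sample keys and ``stack_fn`` entries below are illustrative.

import torch

samples = [{"x": torch.ones(2), "y": 0.5}, {"x": torch.zeros(2), "y": 1.5}]
collected = {k: [s[k] for s in samples] for k in samples[0]}  # dict of lists
stack_fn = {"x": torch.stack, "y": torch.tensor}              # per-key stacking rules
batch = {k: stack_fn[k](v) for k, v in collected.items()}
print(batch["x"].shape)  # torch.Size([2, 2])
print(batch["y"])        # tensor([0.5000, 1.5000])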
@@ -123,30 +151,36 @@ class PinaDataLoader:
self,
dataset_dict,
batch_size,
shuffle=False,
num_workers=0,
collate_fn=None,
shuffle=False,
common_batch_size=True,
separate_conditions=False,
):
self.dataset_dict = dataset_dict
self.batch_size = batch_size
self.shuffle = shuffle
self.num_workers = num_workers
self.collate_fn = collate_fn
self.shuffle = shuffle
self.separate_conditions = separate_conditions
# Batch size None means we want to load the entire dataset in a single
# batch
if batch_size is None:
batch_size_per_dataset = {
split: None for split in dataset_dict.keys()
}
else:
if common_batch_size:
# Compute batch size per dataset
if common_batch_size: # all datasets have the same batch size
# (the sum of the batch sizes is equal to
# n_conditions * batch_size)
batch_size_per_dataset = {
split: batch_size for split in dataset_dict.keys()
}
else:
else: # batch size proportional to dataset size (the sum of the
# batch sizes is equal to the specified batch size)
batch_size_per_dataset = self._compute_batch_size()
# Create a dataloader per dataset
self.dataloaders = {
split: self._create_dataloader(
dataset, batch_size_per_dataset[split]
@@ -158,21 +192,26 @@ class PinaDataLoader:
"""
Compute an appropriate batch size for the given dataset.
"""
# Compute number of elements per dataset
elements_per_dataset = {
dataset_name: len(dataset)
for dataset_name, dataset in self.dataset_dict.items()
}
# Compute the total number of elements
total_elements = sum(el for el in elements_per_dataset.values())
# Compute the portion of each dataset
portion_per_dataset = {
name: el / total_elements
for name, el in elements_per_dataset.items()
}
# Compute batch size per dataset. Ensure at least 1 element per
# dataset.
batch_size_per_dataset = {
name: max(1, int(portion * self.batch_size))
for name, portion in portion_per_dataset.items()
}
# Adjust batch sizes to match the specified total batch size
tot_el_per_batch = sum(el for el in batch_size_per_dataset.values())
if self.batch_size > tot_el_per_batch:
difference = self.batch_size - tot_el_per_batch
while difference > 0:
@@ -194,33 +233,45 @@ class PinaDataLoader:
return batch_size_per_dataset
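A worked example of the proportional split computed above, with made-up condition names and sizes: 100 and 300 points with ``batch_size=32`` give portions 0.25 and 0.75, hence per-condition batch sizes 8 and 24, which already sum to 32 and need no adjustment.

elements = {"physics": 100, "data": 300}  # illustrative condition sizes
batch_size = 32
total = sum(elements.values())
sizes = {
    name: max(1, int(n / total * batch_size)) for name, n in elements.items()
}
print(sizes)  # {'physics': 8, 'data': 24}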
def _create_dataloader(self, dataset, batch_size):
print(batch_size)
if batch_size is None:
"""
Create the dataloader for the given dataset.
"""
# If batch size is None, use DummyDataloader
if batch_size is None or batch_size >= len(dataset):
return DummyDataloader(dataset)
# Determine the appropriate collate function
if not dataset.automatic_batching:
collate_fn = partial(collate_fn_custom, dataset=dataset)
else:
collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn)
# Create and return the dataloader
return DataLoader(
dataset,
batch_size=batch_size,
num_workers=self.num_workers,
collate_fn=collate_fn,
num_workers=self.num_workers,
sampler=PinaSampler(dataset, shuffle=self.shuffle),
)
def __len__(self):
"""
Return the length of the dataloader.
:return: The length of the dataloader.
:rtype: int
"""
# If separate conditions, return sum of lengths of all dataloaders
# else, return max length among dataloaders
if self.separate_conditions:
return sum(len(dl) for dl in self.dataloaders.values())
return max(len(dl) for dl in self.dataloaders.values())
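As a quick numeric illustration (numbers are made up): with per-condition dataloaders of lengths 5 and 3, ``separate_conditions=True`` gives 5 + 3 = 8 steps per epoch, while the default zipped iteration runs max(5, 3) = 5 steps and cycles the shorter loader, as the ``__iter__`` below does.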
def __iter__(self):
"""
Return an iterator that yields dictionaries of batches.
Iterates for a number of steps equal to the longest dataloader (as per __len__)
and restarts the shorter dataloaders when they run out.
Iterate over the dataloader.
:return: Yields batches from the dataloader.
:rtype: dict
"""
if self.separate_conditions:
for split, dl in self.dataloaders.items():
@@ -228,15 +279,19 @@ class PinaDataLoader:
yield {split: batch}
return
iterators = {split: iter(dl) for split, dl in self.dataloaders.items()}
iterators = {
split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
}
for _ in range(len(self)):
batch_dict = {}
for split, it in iterators.items():
try:
batch = next(it)
except StopIteration:
new_it = iter(self.dataloaders[split])
iterators[split] = new_it
batch = next(new_it)
# Iterate through each dataloader and get the next batch
batch = next(it, None)
# Guard against an exhausted iterator (should not happen with itertools.cycle)
if batch is None:
return
batch_dict[split] = batch
yield batch_dict
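Finally, a hedged end-to-end sketch of driving the loader, run against this module with toy stand-ins for PinaDataset: the ``automatic_batching`` flag and ``getitem_from_list`` method are taken from the code above, while the condition names, sizes, and index-returning ``__getitem__`` are assumptions for illustration.

import torch

class ToyDataset:
    automatic_batching = False  # route batches through collate_fn_custom

    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return idx  # hand back the index; fetching happens in getitem_from_list

    def getitem_from_list(self, idx):
        return {"x": self.x[idx]}

loader = PinaDataLoader(
    {"physics": ToyDataset(40), "data": ToyDataset(8)},
    batch_size=4,
    shuffle=False,
)
print(len(loader))  # 10: the longest per-condition dataloader drives the epoch
for batch in loader:
    # one dictionary per step; the shorter "data" loader is cycled
    print({name: b["x"].shape for name, b in batch.items()})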