fix some codacy warnings
@@ -255,7 +255,7 @@ class PinaDataModule(LightningDataModule):
                 dataset_dict[key].update({condition_name: data})
         return dataset_dict
 
-    def _create_dataloader(self, split, dataset):
+    def _create_dataloader(self, dataset):
         """
         Create the dataloader for the given split.
 
@@ -280,7 +280,6 @@ class PinaDataModule(LightningDataModule):
             batch_size=self.batch_size,
             shuffle=self.shuffle,
             num_workers=self.num_workers,
-            collate_fn=None,
             common_batch_size=self.common_batch_size,
             separate_conditions=self.separate_conditions,
         )
@@ -292,7 +291,7 @@ class PinaDataModule(LightningDataModule):
         :return: The validation dataloader
         :rtype: torch.utils.data.DataLoader
         """
-        return self._create_dataloader("val", self.val_dataset)
+        return self._create_dataloader(self.val_dataset)
 
     def train_dataloader(self):
         """
@@ -301,7 +300,7 @@ class PinaDataModule(LightningDataModule):
         :return: The training dataloader
         :rtype: torch.utils.data.DataLoader
         """
-        return self._create_dataloader("train", self.train_dataset)
+        return self._create_dataloader(self.train_dataset)
 
     def test_dataloader(self):
         """
@@ -310,7 +309,7 @@ class PinaDataModule(LightningDataModule):
         :return: The testing dataloader
         :rtype: torch.utils.data.DataLoader
         """
-        return self._create_dataloader("test", self.test_dataset)
+        return self._create_dataloader(self.test_dataset)
 
     @staticmethod
     def _transfer_batch_to_device_dummy(batch, device, dataloader_idx):
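Reviewer note: the three Lightning hooks above now share a single helper that takes only the dataset, with the apparently unused split argument dropped. A minimal self-contained sketch of the same pattern (toy datasets, hypothetical names; assumes the lightning package):

    import torch
    from lightning.pytorch import LightningDataModule
    from torch.utils.data import DataLoader, TensorDataset

    class ToyDataModule(LightningDataModule):
        """Sketch: all three hooks delegate to one dataset-only helper."""

        def __init__(self, batch_size=32):
            super().__init__()
            self.batch_size = batch_size
            data = TensorDataset(torch.randn(100, 3))
            self.train_dataset = self.val_dataset = self.test_dataset = data

        def _create_dataloader(self, dataset):
            return DataLoader(dataset, batch_size=self.batch_size)

        def train_dataloader(self):
            return self._create_dataloader(self.train_dataset)

        def val_dataloader(self):
            return self._create_dataloader(self.val_dataset)

        def test_dataloader(self):
            return self._create_dataloader(self.test_dataset)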
@@ -1,11 +1,17 @@
-from torch.utils.data import DataLoader
+"""DataLoader module for PinaDataset."""
 
+import itertools
 from functools import partial
+import torch
+from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import SequentialSampler
-import torch
 
 
 class DummyDataloader:
+    """
+    DataLoader that returns the entire dataset in a single batch.
+    """
+
     def __init__(self, dataset):
         """
@@ -24,18 +30,18 @@ class DummyDataloader:
         .. note::
             This dataloader is used when the batch size is ``None``.
         """
-        print("Using DummyDataloader")
-        if (
-            torch.distributed.is_available()
-            and torch.distributed.is_initialized()
-        ):
+        # Handle distributed environment
+        if PinaSampler.is_distributed():
+            # Get rank and world size
             rank = torch.distributed.get_rank()
             world_size = torch.distributed.get_world_size()
+            # Ensure dataset is large enough
             if len(dataset) < world_size:
                 raise RuntimeError(
                     "Dimension of the dataset smaller than world size."
                     " Increase the size of the partition or use a single GPU"
                 )
+            # Split dataset among processes
             idx, i = [], rank
             while i < len(dataset):
                 idx.append(i)
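Reviewer note: the distributed branch shards the dataset by striding over indices with the process rank. A standalone sketch of that pattern (the increment by world_size is hidden between hunks, so it is assumed here):

    # Rank r of w processes takes indices r, r + w, r + 2w, ...
    def strided_indices(n_samples, rank, world_size):
        idx, i = [], rank
        while i < n_samples:
            idx.append(i)
            i += world_size  # assumed step, not shown in the hunk
        return idx

    print(strided_indices(10, 0, 4))  # [0, 4, 8]
    print(strided_indices(10, 3, 4))  # [3, 7]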
@@ -43,15 +49,28 @@ class DummyDataloader:
         else:
             idx = list(range(len(dataset)))
 
-        self.dataset = dataset._getitem_from_list(idx)
+        self.dataset = dataset.getitem_from_list(idx)
 
     def __iter__(self):
+        """
+        Iterate over the dataloader.
+        """
         return self
 
     def __len__(self):
+        """
+        Return the length of the dataloader, which is always 1.
+        :return: The length of the dataloader.
+        :rtype: int
+        """
         return 1
 
     def __next__(self):
+        """
+        Return the entire dataset as a single batch.
+        :return: The entire dataset.
+        :rtype: dict
+        """
         return self.dataset
 
 
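Reviewer note: the iterator protocol of DummyDataloader is deliberately trivial: __len__ is 1 and the single batch is the whole (possibly rank-local) dataset. A toy stand-in:

    class SingleBatchLoader:
        """Toy stand-in for DummyDataloader: the whole dataset is one batch."""

        def __init__(self, data):
            self.data = data

        def __iter__(self):
            return self

        def __len__(self):
            return 1

        def __next__(self):
            # Note: like the real class, this never raises StopIteration;
            # consumers are expected to rely on __len__ instead.
            return self.data

    loader = SingleBatchLoader({"x": list(range(5))})
    print(len(loader))         # 1
    print(next(iter(loader)))  # {'x': [0, 1, 2, 3, 4]}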
@@ -70,10 +89,7 @@ class PinaSampler:
         :rtype: :class:`torch.utils.data.Sampler`
         """
 
-        if (
-            torch.distributed.is_available()
-            and torch.distributed.is_initialized()
-        ):
+        if cls.is_distributed():
             sampler = DistributedSampler(dataset, shuffle=shuffle)
         else:
             if shuffle:
@@ -82,6 +98,18 @@ class PinaSampler:
                 sampler = SequentialSampler(dataset)
         return sampler
 
+    @staticmethod
+    def is_distributed():
+        """
+        Check if the sampler is distributed.
+        :return: True if the sampler is distributed, False otherwise.
+        :rtype: bool
+        """
+        return (
+            torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+        )
+
 
 def _collect_items(batch):
     """
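Reviewer note: the duplicated availability-and-initialization check is now centralized in PinaSampler.is_distributed(). The predicate behaves like this sketch:

    import torch

    def is_distributed():
        # True only when the distributed package is compiled in AND a process
        # group was initialized via torch.distributed.init_process_group.
        return torch.distributed.is_available() and torch.distributed.is_initialized()

    print(is_distributed())  # False in a plain single-process script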
@@ -97,11 +125,12 @@ def _collect_items(batch):
 
 
 def collate_fn_custom(batch, dataset):
     """
-    Override the default collate function to handle datasets without automatic batching.
+    Override the default collate function to handle datasets without automatic
+    batching.
     :param batch: List of indices from the dataset.
     :param dataset: The PinaDataset instance (must be provided).
     """
-    return dataset._getitem_from_list(batch)
+    return dataset.getitem_from_list(batch)
 
 
 def collate_fn_default(batch, stack_fn):
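Reviewer note: DataLoader invokes collate_fn with the batch as its only argument, so the extra dataset parameter is bound up front with functools.partial. A minimal sketch with a hypothetical dataset class exposing the renamed public method:

    from functools import partial

    class TinyDataset:
        """Hypothetical stand-in exposing getitem_from_list."""

        def __init__(self, data):
            self.data = data

        def getitem_from_list(self, idx_list):
            return [self.data[i] for i in idx_list]

    def collate_fn_custom(batch, dataset):
        return dataset.getitem_from_list(batch)

    dataset = TinyDataset(["a", "b", "c", "d"])
    collate = partial(collate_fn_custom, dataset=dataset)
    print(collate([0, 2]))  # ['a', 'c'] -- DataLoader would call collate(indices)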
@@ -109,7 +138,6 @@ def collate_fn_default(batch, stack_fn):
     Default collate function that simply returns the batch as is.
     :param batch: List of data samples.
     """
-    print("Using default collate function")
     to_return = _collect_items(batch)
     return {k: stack_fn[k](v) for k, v in to_return.items()}
 
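Reviewer note: the default path merges per-sample dicts and then applies a per-field stacking callable. A runnable sketch; the body of _collect_items is not shown in the diff, so the version below is an assumption:

    import torch

    # Assumed behavior of _collect_items: gather each field across samples.
    def _collect_items(batch):
        return {k: [sample[k] for sample in batch] for k in batch[0]}

    # stack_fn maps each field to the callable used to merge it into one tensor.
    stack_fn = {"x": torch.stack, "y": torch.stack}
    batch = [{"x": torch.zeros(2), "y": torch.ones(1)} for _ in range(4)]
    out = {k: stack_fn[k](v) for k, v in _collect_items(batch).items()}
    print(out["x"].shape)  # torch.Size([4, 2])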
@@ -123,30 +151,36 @@ class PinaDataLoader:
         self,
         dataset_dict,
         batch_size,
-        shuffle=False,
         num_workers=0,
-        collate_fn=None,
+        shuffle=False,
         common_batch_size=True,
         separate_conditions=False,
     ):
         self.dataset_dict = dataset_dict
         self.batch_size = batch_size
-        self.shuffle = shuffle
         self.num_workers = num_workers
-        self.collate_fn = collate_fn
+        self.shuffle = shuffle
         self.separate_conditions = separate_conditions
 
+        # Batch size None means we want to load the entire dataset in a single
+        # batch
         if batch_size is None:
             batch_size_per_dataset = {
                 split: None for split in dataset_dict.keys()
             }
         else:
-            if common_batch_size:
+            # Compute batch size per dataset
+            if common_batch_size:  # all datasets have the same batch size
+                # (the sum of the batch sizes is equal to
+                # n_conditions * batch_size)
                 batch_size_per_dataset = {
                     split: batch_size for split in dataset_dict.keys()
                 }
-            else:
+            else:  # batch size proportional to dataset size (the sum of the
+                # batch sizes is equal to the specified batch size)
                 batch_size_per_dataset = self._compute_batch_size()
 
+        # Create a dataloader per dataset
         self.dataloaders = {
             split: self._create_dataloader(
                 dataset, batch_size_per_dataset[split]
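Reviewer note: the two batching modes commented in the constructor differ as follows: with common_batch_size=True every condition keeps the full batch_size (a combined step holds up to n_conditions * batch_size samples), otherwise batch_size is split proportionally across conditions. With hypothetical sizes:

    conditions = {"interior": 800, "boundary": 200}  # hypothetical dataset sizes
    batch_size = 10

    # common_batch_size=True: every condition uses batch_size as-is
    common = {name: batch_size for name in conditions}

    # common_batch_size=False: batch_size split proportionally (see next hunk)
    total = sum(conditions.values())
    proportional = {
        name: max(1, int(batch_size * n / total)) for name, n in conditions.items()
    }
    print(common)        # {'interior': 10, 'boundary': 10}
    print(proportional)  # {'interior': 8, 'boundary': 2}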
@@ -158,21 +192,26 @@ class PinaDataLoader:
         """
         Compute an appropriate batch size for the given dataset.
         """
+        # Compute number of elements per dataset
         elements_per_dataset = {
             dataset_name: len(dataset)
             for dataset_name, dataset in self.dataset_dict.items()
         }
+        # Compute the total number of elements
         total_elements = sum(el for el in elements_per_dataset.values())
+        # Compute the portion of each dataset
         portion_per_dataset = {
             name: el / total_elements
             for name, el in elements_per_dataset.items()
         }
+        # Compute batch size per dataset. Ensure at least 1 element per
+        # dataset.
         batch_size_per_dataset = {
             name: max(1, int(portion * self.batch_size))
             for name, portion in portion_per_dataset.items()
         }
+        # Adjust batch sizes to match the specified total batch size
         tot_el_per_batch = sum(el for el in batch_size_per_dataset.values())
 
         if self.batch_size > tot_el_per_batch:
             difference = self.batch_size - tot_el_per_batch
             while difference > 0:
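Reviewer note: int() truncation can leave the per-condition sizes summing below the requested total, which is why the loop that follows tops conditions up. A quick illustration:

    portions = {"a": 0.34, "b": 0.33, "c": 0.33}
    batch_size = 10
    sizes = {k: max(1, int(p * batch_size)) for k, p in portions.items()}
    print(sizes, sum(sizes.values()))  # {'a': 3, 'b': 3, 'c': 3} 9 -- one short
    # the while-loop in the diff distributes the remaining difference until
    # the sizes sum to the requested batch_size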
@@ -194,33 +233,45 @@ class PinaDataLoader:
         return batch_size_per_dataset
 
     def _create_dataloader(self, dataset, batch_size):
-        print(batch_size)
-        if batch_size is None:
+        """
+        Create the dataloader for the given dataset.
+        """
+        # If batch size is None, use DummyDataloader
+        if batch_size is None or batch_size >= len(dataset):
             return DummyDataloader(dataset)
 
+        # Determine the appropriate collate function
         if not dataset.automatic_batching:
             collate_fn = partial(collate_fn_custom, dataset=dataset)
         else:
             collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn)
 
+        # Create and return the dataloader
         return DataLoader(
             dataset,
             batch_size=batch_size,
-            num_workers=self.num_workers,
             collate_fn=collate_fn,
+            num_workers=self.num_workers,
             sampler=PinaSampler(dataset, shuffle=self.shuffle),
         )
 
     def __len__(self):
+        """
+        Return the length of the dataloader.
+        :return: The length of the dataloader.
+        :rtype: int
+        """
+        # If separate conditions, return sum of lengths of all dataloaders
+        # else, return max length among dataloaders
         if self.separate_conditions:
             return sum(len(dl) for dl in self.dataloaders.values())
         return max(len(dl) for dl in self.dataloaders.values())
 
     def __iter__(self):
         """
-        Restituisce un iteratore che produce dizionari di batch.
-
-        Itera per un numero di passi pari al dataloader più lungo (come da __len__)
-        e fa ricominciare i dataloader più corti quando si esauriscono.
+        Iterate over the dataloader.
+        :return: Yields batches from the dataloader.
+        :rtype: dict
         """
         if self.separate_conditions:
             for split, dl in self.dataloaders.items():
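Reviewer note: the documented length semantics: with separate_conditions each condition's batches are yielded one at a time, so lengths add; otherwise conditions are combined per step and the longest dataloader sets the epoch length. Numerically:

    lengths = {"interior": 5, "boundary": 3}  # hypothetical dataloader lengths
    print(sum(lengths.values()))  # 8 steps when separate_conditions=True
    print(max(lengths.values()))  # 5 steps when conditions are batched together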
@@ -228,15 +279,19 @@ class PinaDataLoader:
                 yield {split: batch}
             return
 
-        iterators = {split: iter(dl) for split, dl in self.dataloaders.items()}
+        iterators = {
+            split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
+        }
+
         for _ in range(len(self)):
             batch_dict = {}
             for split, it in iterators.items():
-                try:
-                    batch = next(it)
-                except StopIteration:
-                    new_it = iter(self.dataloaders[split])
-                    iterators[split] = new_it
-                    batch = next(new_it)
+                # Iterate through each dataloader and get the next batch
+                batch = next(it, None)
+                # Check if batch is None (in case of uneven lengths)
+                if batch is None:
+                    return
                 batch_dict[split] = batch
             yield batch_dict
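Reviewer note: itertools.cycle replays a dataloader's batches once it is exhausted, replacing the manual StopIteration handling. One subtlety: next(it, None) on a cycle over a non-empty iterable never returns None (cycle caches and repeats the first pass), so the new guard can only fire for an empty dataloader:

    import itertools

    short = itertools.cycle(["b0", "b1"])  # stands in for a 2-batch dataloader
    epoch_len = 5                          # set by the longest dataloader
    print([next(short, None) for _ in range(epoch_len)])
    # ['b0', 'b1', 'b0', 'b1', 'b0'] -- restarts instead of raising StopIteration

    empty = itertools.cycle([])
    print(next(empty, None))               # None: the only case the guard catches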
@@ -9,26 +9,38 @@ from ..label_tensor import LabelTensor
 
 
 class PinaDatasetFactory:
     """
-    TODO: Update docstring
+    Factory class to create PINA datasets based on the provided conditions
+    dictionary.
+    :param dict conditions_dict: A dictionary where keys are condition names
+        and values are dictionaries containing the associated data.
+    :return: A dictionary mapping condition names to their respective
+        :class:`PinaDataset` instances.
     """
 
     def __new__(cls, conditions_dict, **kwargs):
         """
-        TODO: Update docstring
+        Create PINA dataset instances based on the provided conditions
+        dictionary.
+        :param dict conditions_dict: A dictionary where keys are condition names
+            and values are dictionaries containing the associated data.
+        :return: A dictionary mapping condition names to their respective
+            :class:`PinaDataset` instances.
         """
 
         # Check if conditions_dict is empty
         if len(conditions_dict) == 0:
             raise ValueError("No conditions provided")
 
-        dataset_dict = {}
+        dataset_dict = {}  # Dictionary to hold the created datasets
 
         # Check if a Graph is present in the conditions
         for name, data in conditions_dict.items():
+            # Validate that data is a dictionary
             if not isinstance(data, dict):
                 raise ValueError(
                     f"Condition '{name}' data must be a dictionary"
                 )
+            # Create PinaDataset instance for each condition
            dataset_dict[name] = PinaDataset(data, **kwargs)
         return dataset_dict
 
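Reviewer note: the factory's __new__ returns a plain dict rather than an instance, so calling the class acts like a function (and __init__ is never run). A self-contained sketch of the pattern, with a dict standing in for PinaDataset:

    class DatasetFactory:
        """Sketch: __new__ may return any object, here a dict of 'datasets'."""

        def __new__(cls, conditions_dict, **kwargs):
            if len(conditions_dict) == 0:
                raise ValueError("No conditions provided")
            out = {}
            for name, data in conditions_dict.items():
                if not isinstance(data, dict):
                    raise ValueError(f"Condition '{name}' data must be a dictionary")
                out[name] = dict(data, **kwargs)  # stand-in for PinaDataset(data, **kwargs)
            return out

    print(DatasetFactory({"interior": {"x": [1, 2]}}))  # {'interior': {'x': [1, 2]}}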
@@ -90,7 +102,7 @@ class PinaDataset(Dataset):
         }
         return idx
 
-    def _getitem_from_list(self, idx_list):
+    def getitem_from_list(self, idx_list):
         """
         Return data from the dataset given a list of indices.
 
@@ -101,7 +113,7 @@ class PinaDataset(Dataset):
 
         to_return = {}
         for field_name, data in self.data.items():
-            if self.stack_fn[field_name] == LabelBatch.from_data_list:
+            if self.stack_fn[field_name] is LabelBatch.from_data_list:
                 to_return[field_name] = self.stack_fn[field_name](
                     [data[i] for i in idx_list]
                 )
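Reviewer note: the == to is change satisfies a lint rule about comparing callables, but identity tests on methods are subtle: each attribute access on a classmethod such as from_data_list builds a fresh bound-method object, so is only holds if stack_fn stores that exact object. A short demonstration with a stand-in class:

    class LabelBatchLike:
        @classmethod
        def from_data_list(cls, items):
            return items

    f = LabelBatchLike.from_data_list
    print(f == LabelBatchLike.from_data_list)  # True: methods compare by __func__/__self__
    print(f is LabelBatchLike.from_data_list)  # False: each access builds a new bound method

Whether the identity check behaves identically here depends on how stack_fn is populated; plain functions such as torch.stack are unaffected.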