"""DataLoader module for PinaDataset."""
import itertools
from functools import partial
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler
class DummyDataloader:
"""
DataLoader that returns the entire dataset in a single batch.
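
    A minimal usage sketch (``dataset`` stands for any ``PinaDataset``
    instance and is only illustrative)::

        loader = DummyDataloader(dataset, device="cpu")
        batch = next(iter(loader))  # the whole dataset as a single dict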
"""
def __init__(self, dataset, device=None):
"""
Prepare a dataloader object that returns the entire dataset in a single
batch. Depending on the number of GPUs, the dataset is managed
as follows:
- **Distributed Environment** (multiple GPUs): Divides dataset across
processes using the rank and world size. Fetches only portion of
data corresponding to the current process.
- **Non-Distributed Environment** (single GPU): Fetches the entire
dataset.
:param PinaDataset dataset: The dataset object to be processed.
.. note::
This dataloader is used when the batch size is ``None``.
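
        For illustration only (the values are assumptions): with
        ``world_size = 4`` and ``rank = 1``, a dataset of length 10 is
        reduced to the samples at the strided indices::

            idx = [1, 5, 9]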
"""
# Handle distributed environment
if PinaSampler.is_distributed():
# Get rank and world size
rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
# Ensure dataset is large enough
if len(dataset) < world_size:
raise RuntimeError(
"Dimension of the dataset smaller than world size."
" Increase the size of the partition or use a single GPU"
)
# Split dataset among processes
idx, i = [], rank
while i < len(dataset):
idx.append(i)
i += world_size
else:
idx = list(range(len(dataset)))
self.dataset = dataset.getitem_from_list(idx)
self.device = device
self.dataset = (
{k: v.to(self.device) for k, v in self.dataset.items()}
if self.device
else self.dataset
)
def __iter__(self):
"""
Iterate over the dataloader.
"""
return self
def __len__(self):
"""
Return the length of the dataloader, which is always 1.
:return: The length of the dataloader.
:rtype: int
"""
return 1
def __next__(self):
"""
Return the entire dataset as a single batch.
:return: The entire dataset.
:rtype: dict
"""
return self.dataset


class PinaSampler:
    """
    This class is used to create the sampler instance based on the shuffle
    parameter and the environment in which the code is running.
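
    For instance, in a standard single-process run (an assumption used only
    for illustration) the factory falls back to the plain PyTorch samplers::

        sampler = PinaSampler(dataset, shuffle=True)   # RandomSampler
        sampler = PinaSampler(dataset, shuffle=False)  # SequentialSampler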
"""
def __new__(cls, dataset, shuffle=True):
"""
Instantiate and initialize the sampler.
:param PinaDataset dataset: The dataset from which to sample.
:return: The sampler instance.
:rtype: :class:`torch.utils.data.Sampler`
"""
if cls.is_distributed():
sampler = DistributedSampler(dataset, shuffle=shuffle)
else:
if shuffle:
sampler = torch.utils.data.RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
return sampler
@staticmethod
def is_distributed():
"""
Check if the sampler is distributed.
:return: True if the sampler is distributed, False otherwise.
:rtype: bool
"""
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
)


def _collect_items(batch):
    """
    Helper function to collect the items of a batch of data samples into
    per-key lists.

    :param batch: List of data samples, each a dictionary with the same keys.
    :return: A dictionary mapping each key to the list of values collected
        from the batch.
    :rtype: dict
    """
    to_return = {name: [] for name in batch[0].keys()}
    for sample in batch:
        for k, v in sample.items():
            to_return[k].append(v)
    return to_return


def collate_fn_custom(batch, dataset):
    """
    Collate function used when automatic batching is disabled: the batch is a
    list of indices, and the corresponding items are fetched directly from
    the dataset.

    :param batch: List of indices from the dataset.
    :param dataset: The PinaDataset instance (must be provided).
    :return: The batch retrieved from the dataset.
    :rtype: dict
    """
    return dataset.getitem_from_list(batch)


def collate_fn_default(batch, stack_fn):
    """
    Default collate function used when automatic batching is enabled: the
    items of the batch are collected per key and stacked with the key-specific
    stacking functions.

    :param batch: List of data samples.
    :param stack_fn: Dictionary mapping each key to the function used to
        stack the corresponding list of values.
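
    A small sketch of the expected behaviour (``torch.stack`` is only an
    assumed stacking function here)::

        batch = [{"u": torch.rand(3)}, {"u": torch.rand(3)}]
        out = collate_fn_default(batch, stack_fn={"u": torch.stack})
        out["u"].shape  # torch.Size([2, 3])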
"""
to_return = _collect_items(batch)
return {k: stack_fn[k](v) for k, v in to_return.items()}
class PinaDataLoader:
"""
Custom DataLoader for PinaDataset.
"""
def __init__(
self,
dataset_dict,
batch_size,
num_workers=0,
shuffle=False,
batching_mode="common_batch_size",
device=None,
):
"""
Initialize the PinaDataLoader.
:param dict dataset_dict: A dictionary mapping dataset names to their
respective PinaDataset instances.
:param int batch_size: The batch size for the dataloader.
:param int num_workers: Number of worker processes for data loading.
:param bool shuffle: Whether to shuffle the data at every epoch.
:param str batching_mode: The batching mode to use. Options are
"common_batch_size", "separate_conditions", and "proportional".
:param device: The device to which the data should be moved.
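
        A minimal usage sketch (``datasets`` is a hypothetical dictionary of
        ``PinaDataset`` instances, e.g. one per condition)::

            loader = PinaDataLoader(
                datasets,
                batch_size=32,
                shuffle=True,
                batching_mode="proportional",
            )
            for batch in loader:
                ...  # dict mapping condition names to their batches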
"""
self.dataset_dict = dataset_dict
self.batch_size = batch_size
self.num_workers = num_workers
self.shuffle = shuffle
self.batching_mode = batching_mode.lower()
self.device = device
# Batch size None means we want to load the entire dataset in a single
# batch
if batch_size is None:
batch_size_per_dataset = {
split: None for split in dataset_dict.keys()
}
else:
# Compute batch size per dataset
if batching_mode in ["common_batch_size", "separate_conditions"]:
# (the sum of the batch sizes is equal to
# n_conditions * batch_size)
batch_size_per_dataset = {
split: batch_size for split in dataset_dict.keys()
}
elif batching_mode == "propotional":
# batch sizes is equal to the specified batch size)
batch_size_per_dataset = self._compute_batch_size()
# Creaete a dataloader per dataset
self.dataloaders = {
split: self._create_dataloader(
dataset, batch_size_per_dataset[split]
)
for split, dataset in dataset_dict.items()
}
def _compute_batch_size(self):
"""
Compute an appropriate batch size for the given dataset.
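
        For example (dataset names and sizes are assumptions, for
        illustration only): with two datasets of 300 and 100 samples and
        ``batch_size=20``, the resulting batch sizes are::

            {"data": 15, "physics": 5}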
"""
# Compute number of elements per dataset
elements_per_dataset = {
dataset_name: len(dataset)
for dataset_name, dataset in self.dataset_dict.items()
}
# Compute the total number of elements
total_elements = sum(el for el in elements_per_dataset.values())
# Compute the portion of each dataset
portion_per_dataset = {
name: el / total_elements
for name, el in elements_per_dataset.items()
}
# Compute batch size per dataset. Ensure at least 1 element per
# dataset.
batch_size_per_dataset = {
name: max(1, int(portion * self.batch_size))
for name, portion in portion_per_dataset.items()
}
# Adjust batch sizes to match the specified total batch size
tot_el_per_batch = sum(el for el in batch_size_per_dataset.values())
if self.batch_size > tot_el_per_batch:
difference = self.batch_size - tot_el_per_batch
while difference > 0:
for k, v in batch_size_per_dataset.items():
if difference == 0:
break
if v > 1:
batch_size_per_dataset[k] += 1
difference -= 1
if self.batch_size < tot_el_per_batch:
difference = tot_el_per_batch - self.batch_size
while difference > 0:
for k, v in batch_size_per_dataset.items():
if difference == 0:
break
if v > 1:
batch_size_per_dataset[k] -= 1
difference -= 1
return batch_size_per_dataset
def _create_dataloader(self, dataset, batch_size):
"""
Create the dataloader for the given dataset.
:param PinaDataset dataset: The dataset for which to create the
dataloader.
:param int batch_size: The batch size for the dataloader.
:return: The created dataloader.
:rtype: :class:`torch.utils.data.DataLoader`
"""
# If batch size is None, use DummyDataloader
if batch_size is None or batch_size >= len(dataset):
return DummyDataloader(dataset, device=self.device)
# Determine the appropriate collate function
if not dataset.automatic_batching:
collate_fn = partial(collate_fn_custom, dataset=dataset)
else:
collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn)
# Create and return the dataloader
return DataLoader(
dataset,
batch_size=batch_size,
collate_fn=collate_fn,
num_workers=self.num_workers,
sampler=PinaSampler(dataset, shuffle=self.shuffle),
)
def __len__(self):
"""
Return the length of the dataloader.
:return: The length of the dataloader.
:rtype: int
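
        For instance (the lengths are assumptions, for illustration only),
        with two internal dataloaders of lengths 3 and 5::

            len(loader)  # 8 with "separate_conditions", 5 otherwise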
"""
# If separate conditions, return sum of lengths of all dataloaders
# else, return max length among dataloaders
if self.batching_mode == "separate_conditions":
return sum(len(dl) for dl in self.dataloaders.values())
return max(len(dl) for dl in self.dataloaders.values())
def __iter__(self):
"""
Iterate over the dataloader.
:return: Yields batches from the dataloader.
:rtype: dict
"""
if self.batching_mode == "separate_conditions":
for split, dl in self.dataloaders.items():
for batch in dl:
yield {split: batch}
return
iterators = {
split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
}
for _ in range(len(self)):
batch_dict = {}
for split, it in iterators.items():
# Iterate through each dataloader and get the next batch
batch = next(it, None)
# Check if batch is None (in case of uneven lengths)
if batch is None:
return
batch_dict[split] = batch
yield batch_dict