From 2521e4d2bda7240c59334667cca3469be51ab219 Mon Sep 17 00:00:00 2001
From: FilippoOlivo
Date: Sat, 22 Nov 2025 12:59:22 +0100
Subject: [PATCH] add stacked batching mode option

---
 pina/data/dataloader.py         | 53 +++++++++++++++++++++++----------
 pina/data/stacked_dataloader.py | 65 +++++++++++++++++++++++++++++++++
 2 files changed, 104 insertions(+), 14 deletions(-)
 create mode 100644 pina/data/stacked_dataloader.py

diff --git a/pina/data/dataloader.py b/pina/data/dataloader.py
index b8d4a63..6feab1f 100644
--- a/pina/data/dataloader.py
+++ b/pina/data/dataloader.py
@@ -1,11 +1,13 @@
 """DataLoader module for PinaDataset."""
 
 import itertools
+import random
 from functools import partial
 import torch
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.utils.data.sampler import SequentialSampler
+from .stacked_dataloader import StackedDataLoader
 
 
 class DummyDataloader:
@@ -152,6 +154,23 @@ class PinaDataLoader:
     Custom DataLoader for PinaDataset.
     """
 
+    def __new__(cls, *args, **kwargs):
+        batching_mode = kwargs.get("batching_mode", "common_batch_size").lower()
+        batch_size = kwargs.get("batch_size")
+        if batching_mode == "stacked" and batch_size is not None:
+            return StackedDataLoader(
+                args[0],
+                batch_size=batch_size,
+                shuffle=kwargs.get("shuffle", True),
+            )
+        if batch_size is None:
+            # Without a batch size there is nothing to stack or split:
+            # fall back to proportional batching.
+            batching_mode = "proportional"
+            kwargs["batching_mode"] = batching_mode
+        print("Using PinaDataLoader with batching mode:", batching_mode)
+        return super().__new__(cls)
+
     def __init__(
         self,
         dataset_dict,
@@ -193,9 +212,10 @@ class PinaDataLoader:
             # (the sum of the batch sizes is equal to
             # n_conditions * batch_size)
             batch_size_per_dataset = {
-                split: batch_size for split in dataset_dict.keys()
+                split: min(batch_size, len(ds))
+                for split, ds in dataset_dict.items()
             }
-        elif batching_mode == "propotional":
+        elif batching_mode == "proportional":
             # batch sizes is equal to the specified batch size)
             batch_size_per_dataset = self._compute_batch_size()
 
@@ -296,30 +316,35 @@ class PinaDataLoader:
     def __iter__(self):
         """
-        Iterate over the dataloader.
+        Iterate over the dataloader. In ``separate_conditions`` mode, the
+        batches of all conditions are collected and shuffled before being
+        yielded; otherwise every condition contributes to each batch.
 
-        :return: Yields batches from the dataloader.
-        :rtype: dict
+        :return: Yields dictionaries mapping each split name to its batch.
+        :rtype: dict
         """
         if self.batching_mode == "separate_conditions":
+            # Collect one epoch of batches from every condition, then
+            # shuffle so that conditions are interleaved.
+            batches = []
             for split, dl in self.dataloaders.items():
                 for batch in dl:
-                    yield {split: batch}
+                    batches.append({split: batch})
+            random.shuffle(batches)
+            yield from batches
             return
 
+        # "common_batch_size" and "proportional" modes: every batch draws
+        # from all conditions at once.
         iterators = {
             split: itertools.cycle(dl)
             for split, dl in self.dataloaders.items()
         }
 
+        # Iterate for the length of the longest dataloader.
        for _ in range(len(self)):
             batch_dict = {}
             for split, it in iterators.items():
-
-                # Iterate through each dataloader and get the next batch
-                batch = next(it, None)
-                # Check if batch is None (in case of uneven lengths)
-                if batch is None:
-                    return
-
-                batch_dict[split] = batch
+                # itertools.cycle restarts exhausted dataloaders, so next()
+                # always yields a batch even when lengths are uneven.
+                batch_dict[split] = next(it)
             yield batch_dict
diff --git a/pina/data/stacked_dataloader.py b/pina/data/stacked_dataloader.py
new file mode 100644
index 0000000..46bc547
--- /dev/null
+++ b/pina/data/stacked_dataloader.py
@@ -0,0 +1,65 @@
+"""Stacked DataLoader for PinaDataset."""
+
+import torch
+from math import ceil
+
+
+class StackedDataLoader:
+    """
+    DataLoader that stacks all conditions into a single flat index space,
+    so that every batch is drawn across conditions.
+    """
+
+    def __init__(self, datasets, batch_size=32, shuffle=True):
+        for dataset in datasets.values():
+            if dataset.is_graph_dataset:
+                raise ValueError(
+                    "StackedDataLoader does not support graph datasets."
+                )
+        self.chunks = {}
+        self.total_length = 0
+        self._init_chunks(datasets)
+        self.indices = list(range(self.total_length))
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        if self.shuffle:
+            # Use a local generator to avoid reseeding the global RNG.
+            generator = torch.Generator().manual_seed(42)
+            self.indices = torch.randperm(
+                self.total_length, generator=generator
+            ).tolist()
+        self.datasets = datasets
+
+    def _init_chunks(self, datasets):
+        # Assign each dataset a contiguous [start, end) range in the
+        # flat index space.
+        inc = 0
+        for name, dataset in datasets.items():
+            self.chunks[name] = {"start": inc, "end": inc + len(dataset)}
+            inc += len(dataset)
+        self.total_length = inc
+
+    def __len__(self):
+        return ceil(self.total_length / self.batch_size)
+
+    def _build_batch_indices(self, batch_idx):
+        start = batch_idx * self.batch_size
+        end = min(start + self.batch_size, self.total_length)
+        return self.indices[start:end]
+
+    def __iter__(self):
+        for batch_idx in range(len(self)):
+            batch_indices = self._build_batch_indices(batch_idx)
+            batch_data = {}
+            for name, chunk in self.chunks.items():
+                # Map flat indices back to dataset-local indices.
+                local_indices = [
+                    idx - chunk["start"]
+                    for idx in batch_indices
+                    if chunk["start"] <= idx < chunk["end"]
+                ]
+                if local_indices:
+                    batch_data[name] = self.datasets[name].getitem_from_list(
+                        local_indices
+                    )
+            yield batch_data
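
Usage note: the stacked mode flattens every condition into one index space
and slices batches across it, so a single batch can mix samples from several
conditions. A minimal sketch of that behaviour, using a hypothetical
ToyDataset that stands in for a PinaDataset and implements just the members
StackedDataLoader relies on (is_graph_dataset, __len__, getitem_from_list):

    import torch

    from pina.data.stacked_dataloader import StackedDataLoader

    class ToyDataset:
        """Hypothetical stand-in for a PinaDataset; not part of PINA."""

        is_graph_dataset = False

        def __init__(self, n):
            self.data = torch.arange(n)

        def __len__(self):
            return len(self.data)

        def getitem_from_list(self, indices):
            # Return the samples at the given local indices.
            return self.data[indices]

    datasets = {"interior": ToyDataset(10), "boundary": ToyDataset(4)}
    loader = StackedDataLoader(datasets, batch_size=6, shuffle=True)
    for batch in loader:
        # Each batch is a dict holding one or both conditions, with at
        # most 6 samples in total across them.
        print({name: data.tolist() for name, data in batch.items()})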
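
Dispatch note: mode selection happens in PinaDataLoader.__new__, which
returns a StackedDataLoader instead of a PinaDataLoader whenever both
batching_mode="stacked" and a batch_size are passed as keyword arguments
(positional values would bypass the kwargs.get checks; Python then skips
PinaDataLoader.__init__ because the returned object is not an instance of
the class). A hedged sketch of the intended call, reusing the datasets dict
from the example above and assuming the dataset dict is the first positional
argument, as in the dispatch code:

    from pina.data.dataloader import PinaDataLoader

    loader = PinaDataLoader(
        datasets,                 # dict mapping condition name to dataset
        batch_size=6,
        shuffle=True,
        batching_mode="stacked",  # dispatches to StackedDataLoader
    )
    assert isinstance(loader, StackedDataLoader)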