add stack conditions option

This commit is contained in:
FilippoOlivo
2025-11-22 12:59:22 +01:00
parent 0b877d86b9
commit 2521e4d2bd
2 changed files with 91 additions and 16 deletions

View File

@@ -1,11 +1,13 @@
"""DataLoader module for PinaDataset."""
import itertools
import random
from functools import partial
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler
from .stacked_dataloader import StackedDataLoader
class DummyDataloader:
@@ -152,6 +154,22 @@ class PinaDataLoader:
Custom DataLoader for PinaDataset.
"""
def __new__(cls, *args, **kwargs):
batching_mode = kwargs.get("batching_mode", "common_batch_size").lower()
batch_size = kwargs.get("batch_size")
if batching_mode == "stacked" and batch_size is not None:
return StackedDataLoader(
args[0],
batch_size=batch_size,
shuffle=kwargs.get("shuffle", True),
)
elif batch_size is None:
kwargs["batching_mode"] = "proportional"
print(
"Using PinaDataLoader with batching mode:", kwargs["batching_mode"]
)
return super(PinaDataLoader, cls).__new__(cls)
def __init__(
self,
dataset_dict,
@@ -193,9 +211,10 @@ class PinaDataLoader:
# (the sum of the batch sizes is equal to
# n_conditions * batch_size)
batch_size_per_dataset = {
split: batch_size for split in dataset_dict.keys()
split: min(batch_size, len(ds))
for split, ds in dataset_dict.items()
}
elif batching_mode == "propotional":
elif batching_mode == "proportional":
# batch sizes is equal to the specified batch size)
batch_size_per_dataset = self._compute_batch_size()
@@ -296,30 +315,33 @@ class PinaDataLoader:
def __iter__(self):
"""
Iterate over the dataloader.
Iterate over the dataloader. Yields a dictionary mapping split name to batch.
:return: Yields batches from the dataloader.
:rtype: dict
The iteration logic for 'separate_conditions' is now iterative and memory-efficient.
"""
if self.batching_mode == "separate_conditions":
tmp = []
for split, dl in self.dataloaders.items():
for batch in dl:
yield {split: batch}
len_split = len(dl)
for i, batch in enumerate(dl):
tmp.append({split: batch})
if i + 1 >= len_split:
break
random.shuffle(tmp)
for batch_dict in tmp:
yield batch_dict
return
# Common_batch_size or Proportional mode (round-robin sampling)
iterators = {
split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
}
# Iterate for the length of the longest dataloader
for _ in range(len(self)):
batch_dict = {}
batch_dict: BatchDict = {}
for split, it in iterators.items():
# Iterate through each dataloader and get the next batch
batch = next(it, None)
# Check if batch is None (in case of uneven lengths)
if batch is None:
return
batch_dict[split] = batch
# Since we use itertools.cycle, next(it) will always yield a batch
# by repeating the dataset, so no need for the 'if batch is None: return' check.
batch_dict[split] = next(it)
yield batch_dict

View File

@@ -0,0 +1,53 @@
import torch
from math import ceil
class StackedDataLoader:
def __init__(self, datasets, batch_size=32, shuffle=True):
for d in datasets.values():
if d.is_graph_dataset:
raise ValueError("Each dataset must be a dictionary")
self.chunks = {}
self.total_length = 0
self.indices = []
self._init_chunks(datasets)
self.indices = list(range(self.total_length))
self.batch_size = batch_size
self.shuffle = shuffle
if self.shuffle:
torch.random.manual_seed(42)
self.indices = torch.randperm(self.total_length).tolist()
self.datasets = datasets
def _init_chunks(self, datasets):
inc = 0
total_length = 0
for name, dataset in datasets.items():
self.chunks[name] = {"start": inc, "end": inc + len(dataset)}
inc += len(dataset)
self.total_length = inc
def __len__(self):
return ceil(self.total_length / self.batch_size)
def _build_batch_indices(self, batch_idx):
start = batch_idx * self.batch_size
end = min(start + self.batch_size, self.total_length)
return self.indices[start:end]
def __iter__(self):
for batch_idx in range(len(self)):
batch_indices = self._build_batch_indices(batch_idx)
batch_data = {}
for name, chunk in self.chunks.items():
local_indices = [
idx - chunk["start"]
for idx in batch_indices
if chunk["start"] <= idx < chunk["end"]
]
if local_indices:
batch_data[name] = self.datasets[name].getitem_from_list(
local_indices
)
yield batch_data