add stack conditions option

FilippoOlivo
2025-11-22 12:59:22 +01:00
parent 0b877d86b9
commit 2521e4d2bd
2 changed files with 91 additions and 16 deletions

View File

@@ -1,11 +1,13 @@
"""DataLoader module for PinaDataset.""" """DataLoader module for PinaDataset."""
import itertools import itertools
import random
from functools import partial from functools import partial
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler from torch.utils.data.sampler import SequentialSampler
from .stacked_dataloader import StackedDataLoader
class DummyDataloader: class DummyDataloader:
@@ -152,6 +154,22 @@ class PinaDataLoader:
     Custom DataLoader for PinaDataset.
     """

+    def __new__(cls, *args, **kwargs):
+        batching_mode = kwargs.get("batching_mode", "common_batch_size").lower()
+        batch_size = kwargs.get("batch_size")
+        if batching_mode == "stacked" and batch_size is not None:
+            return StackedDataLoader(
+                args[0],
+                batch_size=batch_size,
+                shuffle=kwargs.get("shuffle", True),
+            )
+        elif batch_size is None:
+            kwargs["batching_mode"] = "proportional"
+            print(
+                "Using PinaDataLoader with batching mode:", kwargs["batching_mode"]
+            )
+        return super(PinaDataLoader, cls).__new__(cls)
+
     def __init__(
         self,
         dataset_dict,
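
With this dispatch, the batching mode is resolved before __init__ runs: batching_mode="stacked" with an explicit batch_size short-circuits to a StackedDataLoader (whose __init__ is then never called on a PinaDataLoader instance), while a missing batch_size falls back to proportional batching, relying on type.__call__ passing the same kwargs dict on to __init__. Note that both values must be given as keyword arguments, since __new__ only inspects kwargs. A minimal sketch of the resulting behavior; dataset_dict below is a hypothetical stand-in for a real {split_name: PinaDataset} mapping, not part of this diff:

# Hypothetical illustration of the dispatch in __new__ above.
stacked = PinaDataLoader(dataset_dict, batch_size=32, batching_mode="stacked")
# -> returns a StackedDataLoader; PinaDataLoader.__init__ is skipped

regular = PinaDataLoader(dataset_dict, batch_size=32)
# -> ordinary PinaDataLoader in the default "common_batch_size" mode

fallback = PinaDataLoader(dataset_dict, batch_size=None)
# -> batching_mode is rewritten to "proportional" before __init__ runs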
@@ -193,9 +211,10 @@ class PinaDataLoader:
             # (the sum of the batch sizes is equal to
             # n_conditions * batch_size)
             batch_size_per_dataset = {
-                split: batch_size for split in dataset_dict.keys()
+                split: min(batch_size, len(ds))
+                for split, ds in dataset_dict.items()
             }
-        elif batching_mode == "propotional":
+        elif batching_mode == "proportional":
             # batch sizes is equal to the specified batch size)
             batch_size_per_dataset = self._compute_batch_size()
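
The new min() clamp keeps a split's batch size from exceeding the split itself. A small worked example of the common_batch_size computation, with hypothetical split lengths:

# Hypothetical split lengths illustrating the clamped computation above.
dataset_lengths = {"physics": 100, "boundary": 8}
batch_size = 32
batch_size_per_dataset = {
    split: min(batch_size, n) for split, n in dataset_lengths.items()
}
# {"physics": 32, "boundary": 8} -- previously both splits got 32,
# oversizing the 8-sample split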
@@ -296,30 +315,33 @@ class PinaDataLoader:
     def __iter__(self):
         """
-        Iterate over the dataloader.
-        :return: Yields batches from the dataloader.
-        :rtype: dict
+        Iterate over the dataloader, yielding a dictionary that maps each
+        split name to a batch. In 'separate_conditions' mode, batches are
+        buffered per condition and shuffled across conditions before being
+        yielded.
         """
         if self.batching_mode == "separate_conditions":
+            tmp = []
             for split, dl in self.dataloaders.items():
-                for batch in dl:
-                    yield {split: batch}
+                len_split = len(dl)
+                for i, batch in enumerate(dl):
+                    tmp.append({split: batch})
+                    if i + 1 >= len_split:
+                        break
+            random.shuffle(tmp)
+            for batch_dict in tmp:
+                yield batch_dict
             return
+        # common_batch_size or proportional mode (round-robin sampling)
         iterators = {
             split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
         }
+        # Iterate for the length of the longest dataloader.
         for _ in range(len(self)):
             batch_dict = {}
             for split, it in iterators.items():
-                # Iterate through each dataloader and get the next batch
-                batch = next(it, None)
-                # Check if batch is None (in case of uneven lengths)
-                if batch is None:
-                    return
-                batch_dict[split] = batch
+                # itertools.cycle repeats each dataloader, so next(it)
+                # always yields a batch; no None check is needed.
+                batch_dict[split] = next(it)
             yield batch_dict
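
In the round-robin branch, itertools.cycle restarts each exhausted dataloader, so every step yields exactly one batch per split and shorter splits simply repeat. A self-contained sketch of that pattern, with plain lists standing in for the per-split dataloaders:

import itertools

dataloaders = {"physics": [1, 2, 3], "boundary": ["a"]}  # stand-in batches
iterators = {s: itertools.cycle(dl) for s, dl in dataloaders.items()}
for _ in range(max(len(dl) for dl in dataloaders.values())):
    print({s: next(it) for s, it in iterators.items()})
# {'physics': 1, 'boundary': 'a'}
# {'physics': 2, 'boundary': 'a'}   <- "boundary" repeats via cycle
# {'physics': 3, 'boundary': 'a'}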

View File

@@ -0,0 +1,53 @@
+import torch
+from math import ceil
+
+
+class StackedDataLoader:
+    def __init__(self, datasets, batch_size=32, shuffle=True):
+        for d in datasets.values():
+            if d.is_graph_dataset:
+                raise ValueError(
+                    "StackedDataLoader does not support graph datasets."
+                )
+        self.chunks = {}
+        self.total_length = 0
+        self._init_chunks(datasets)
+        self.indices = list(range(self.total_length))
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        if self.shuffle:
+            # Fixed seed so the stacked order is reproducible across runs.
+            torch.random.manual_seed(42)
+            self.indices = torch.randperm(self.total_length).tolist()
+        self.datasets = datasets
+
+    def _init_chunks(self, datasets):
+        # Assign each dataset a contiguous [start, end) slice of a
+        # global index space.
+        inc = 0
+        for name, dataset in datasets.items():
+            self.chunks[name] = {"start": inc, "end": inc + len(dataset)}
+            inc += len(dataset)
+        self.total_length = inc
+
+    def __len__(self):
+        return ceil(self.total_length / self.batch_size)
+
+    def _build_batch_indices(self, batch_idx):
+        start = batch_idx * self.batch_size
+        end = min(start + self.batch_size, self.total_length)
+        return self.indices[start:end]
+
+    def __iter__(self):
+        for batch_idx in range(len(self)):
+            batch_indices = self._build_batch_indices(batch_idx)
+            batch_data = {}
+            # Translate global indices back to per-dataset local indices.
+            for name, chunk in self.chunks.items():
+                local_indices = [
+                    idx - chunk["start"]
+                    for idx in batch_indices
+                    if chunk["start"] <= idx < chunk["end"]
+                ]
+                if local_indices:
+                    batch_data[name] = self.datasets[name].getitem_from_list(
+                        local_indices
+                    )
+            yield batch_data
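
The chunk bookkeeping above places every dataset in one contiguous slice of a global index space, and each batch translates its global indices back to local, per-dataset ones, so a single batch can span dataset boundaries. A self-contained sketch of that behavior; ToyDataset is a hypothetical stand-in for a PinaDataset split, assumed only to expose the is_graph_dataset flag and the getitem_from_list interface used above:

class ToyDataset:  # hypothetical stand-in for a PinaDataset split
    is_graph_dataset = False

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def getitem_from_list(self, idxs):  # interface assumed by StackedDataLoader
        return [self.data[i] for i in idxs]

datasets = {
    "physics": ToyDataset(list(range(5))),    # global indices 0..4
    "boundary": ToyDataset([100, 101, 102]),  # global indices 5..7
}
loader = StackedDataLoader(datasets, batch_size=4, shuffle=False)
for batch in loader:
    print(batch)
# {'physics': [0, 1, 2, 3]}
# {'physics': [4], 'boundary': [100, 101, 102]}  <- batch spans datasets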