add stack conditions option

This commit is contained in:
FilippoOlivo
2025-11-22 12:59:22 +01:00
parent 0b877d86b9
commit 2521e4d2bd
2 changed files with 91 additions and 16 deletions

View File

@@ -0,0 +1,53 @@
import torch
from math import ceil
class StackedDataLoader:
    """Batch several named datasets as one virtually concatenated dataset.

    The datasets are laid end to end in the iteration order of ``datasets``:
    each one owns a contiguous range of global indices (its "chunk"). A
    batch is a slice of the (optionally shuffled) global index list; the
    indices that fall into a dataset's chunk are translated to that
    dataset's local indices and fetched with ``getitem_from_list``.

    Each dataset must provide ``is_graph_dataset``, ``__len__`` and
    ``getitem_from_list(list_of_local_indices)``.

    Attributes set by the constructor:
        datasets: The mapping passed in by the caller.
        batch_size: Number of global indices per batch.
        shuffle: Whether the global index order was permuted.
        chunks: ``{name: {"start": int, "end": int}}`` half-open global
            index range owned by each dataset.
        total_length: Sum of the lengths of all datasets.
        indices: Global index order used to build the batches.
    """

    def __init__(self, datasets, batch_size=32, shuffle=True):
        """Build the chunk table and the global index order.

        Args:
            datasets: Mapping from a name to a dataset object.
            batch_size: Number of global indices per batch; must be positive.
            shuffle: If true, the global indices are permuted once, here,
                with a fixed seed (42). The order is therefore reproducible
                and identical on every pass over the loader.

        Raises:
            ValueError: If any dataset reports ``is_graph_dataset`` as true,
                or if ``batch_size`` is not positive.
        """
        for dataset in datasets.values():
            if dataset.is_graph_dataset:
                raise ValueError("Each dataset must be a dictionary")
        if batch_size <= 0:
            # Fail early instead of a ZeroDivisionError later in __len__.
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        self.datasets = datasets
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.chunks = {}
        self.total_length = 0
        self._init_chunks(datasets)

        if self.shuffle:
            # Use a private generator: seeding the global torch RNG here
            # would silently alter every later random draw in the process.
            generator = torch.Generator()
            generator.manual_seed(42)
            self.indices = torch.randperm(
                self.total_length, generator=generator
            ).tolist()
        else:
            self.indices = list(range(self.total_length))

    def _init_chunks(self, datasets):
        """Assign each dataset its half-open range of global indices.

        Fills ``self.chunks`` and sets ``self.total_length`` to the sum of
        the dataset lengths.
        """
        offset = 0
        for name, dataset in datasets.items():
            size = len(dataset)
            self.chunks[name] = {"start": offset, "end": offset + size}
            offset += size
        self.total_length = offset

    def __len__(self):
        """Return the number of batches; the last one may be smaller."""
        return ceil(self.total_length / self.batch_size)

    def _build_batch_indices(self, batch_idx):
        """Return the global indices that make up batch ``batch_idx``."""
        start = batch_idx * self.batch_size
        end = min(start + self.batch_size, self.total_length)
        return self.indices[start:end]

    def __iter__(self):
        """Yield one dict per batch, mapping dataset name to fetched items.

        A dataset name is present in a batch only if at least one of the
        batch's global indices falls inside that dataset's chunk. The value
        is whatever the dataset's ``getitem_from_list`` returns for the
        corresponding local indices, kept in batch order.
        """
        for batch_idx in range(len(self)):
            batch_indices = self._build_batch_indices(batch_idx)
            batch_data = {}
            for name, chunk in self.chunks.items():
                start, end = chunk["start"], chunk["end"]
                # Global -> local index for the entries owned by this chunk.
                local_indices = [
                    idx - start for idx in batch_indices if start <= idx < end
                ]
                if local_indices:
                    batch_data[name] = self.datasets[name].getitem_from_list(
                        local_indices
                    )
            yield batch_data