fix data pipeline and add separeate_conditions option

This commit is contained in:
FilippoOlivo
2025-11-12 15:59:28 +01:00
parent 99e2f07cf7
commit 4d172a8821
3 changed files with 30 additions and 137 deletions

View File

@@ -127,14 +127,14 @@ class PinaDataLoader:
num_workers=0,
collate_fn=None,
common_batch_size=True,
separate_conditions=False,
):
self.dataset_dict = dataset_dict
self.batch_size = batch_size
self.shuffle = shuffle
self.num_workers = num_workers
self.collate_fn = collate_fn
print(batch_size)
self.separate_conditions = separate_conditions
if batch_size is None:
batch_size_per_dataset = {
@@ -211,6 +211,8 @@ class PinaDataLoader:
)
def __len__(self):
if self.separate_conditions:
return sum(len(dl) for dl in self.dataloaders.values())
return max(len(dl) for dl in self.dataloaders.values())
def __iter__(self):
@@ -220,26 +222,21 @@ class PinaDataLoader:
Itera per un numero di passi pari al dataloader più lungo (come da __len__)
e fa ricominciare i dataloader più corti quando si esauriscono.
"""
# 1. Crea un iteratore per ogni dataloader
if self.separate_conditions:
for split, dl in self.dataloaders.items():
for batch in dl:
yield {split: batch}
return
iterators = {split: iter(dl) for split, dl in self.dataloaders.items()}
# 2. Itera per il numero di batch del dataloader più lungo
for _ in range(len(self)):
# 3. Prepara il dizionario di batch per questo step
batch_dict = {}
# 4. Ottieni il prossimo batch da ogni iteratore
for split, it in iterators.items():
try:
batch = next(it)
except StopIteration:
# 5. Se un iteratore è esaurito, resettalo e prendi il primo batch
new_it = iter(self.dataloaders[split])
iterators[split] = new_it # Salva il nuovo iteratore
iterators[split] = new_it
batch = next(new_it)
batch_dict[split] = batch
# 6. Restituisci il dizionario di batch
yield batch_dict