fix data pipeline and add separeate_conditions option
This commit is contained in:
@@ -127,14 +127,14 @@ class PinaDataLoader:
|
||||
num_workers=0,
|
||||
collate_fn=None,
|
||||
common_batch_size=True,
|
||||
separate_conditions=False,
|
||||
):
|
||||
self.dataset_dict = dataset_dict
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.num_workers = num_workers
|
||||
self.collate_fn = collate_fn
|
||||
|
||||
print(batch_size)
|
||||
self.separate_conditions = separate_conditions
|
||||
|
||||
if batch_size is None:
|
||||
batch_size_per_dataset = {
|
||||
@@ -211,6 +211,8 @@ class PinaDataLoader:
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
if self.separate_conditions:
|
||||
return sum(len(dl) for dl in self.dataloaders.values())
|
||||
return max(len(dl) for dl in self.dataloaders.values())
|
||||
|
||||
def __iter__(self):
|
||||
@@ -220,26 +222,21 @@ class PinaDataLoader:
|
||||
Itera per un numero di passi pari al dataloader più lungo (come da __len__)
|
||||
e fa ricominciare i dataloader più corti quando si esauriscono.
|
||||
"""
|
||||
# 1. Crea un iteratore per ogni dataloader
|
||||
if self.separate_conditions:
|
||||
for split, dl in self.dataloaders.items():
|
||||
for batch in dl:
|
||||
yield {split: batch}
|
||||
return
|
||||
|
||||
iterators = {split: iter(dl) for split, dl in self.dataloaders.items()}
|
||||
|
||||
# 2. Itera per il numero di batch del dataloader più lungo
|
||||
for _ in range(len(self)):
|
||||
|
||||
# 3. Prepara il dizionario di batch per questo step
|
||||
batch_dict = {}
|
||||
|
||||
# 4. Ottieni il prossimo batch da ogni iteratore
|
||||
for split, it in iterators.items():
|
||||
try:
|
||||
batch = next(it)
|
||||
except StopIteration:
|
||||
# 5. Se un iteratore è esaurito, resettalo e prendi il primo batch
|
||||
new_it = iter(self.dataloaders[split])
|
||||
iterators[split] = new_it # Salva il nuovo iteratore
|
||||
iterators[split] = new_it
|
||||
batch = next(new_it)
|
||||
|
||||
batch_dict[split] = batch
|
||||
|
||||
# 6. Restituisci il dizionario di batch
|
||||
yield batch_dict
|
||||
|
||||
Reference in New Issue
Block a user