Add stacked conditions option

FilippoOlivo
2025-11-22 12:59:22 +01:00
parent 0b877d86b9
commit 2521e4d2bd
2 changed files with 91 additions and 16 deletions


@@ -1,11 +1,13 @@
"""DataLoader module for PinaDataset."""
import itertools
import random
from functools import partial
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import SequentialSampler
from .stacked_dataloader import StackedDataLoader
class DummyDataloader:
@@ -152,6 +154,22 @@ class PinaDataLoader:
     Custom DataLoader for PinaDataset.
     """

+    def __new__(cls, *args, **kwargs):
+        batching_mode = kwargs.get("batching_mode", "common_batch_size").lower()
+        batch_size = kwargs.get("batch_size")
+        if batching_mode == "stacked" and batch_size is not None:
+            return StackedDataLoader(
+                args[0],
+                batch_size=batch_size,
+                shuffle=kwargs.get("shuffle", True),
+            )
+        elif batch_size is None:
+            kwargs["batching_mode"] = "proportional"
+            print(
+                "Using PinaDataLoader with batching mode:", kwargs["batching_mode"]
+            )
+        return super(PinaDataLoader, cls).__new__(cls)
+
     def __init__(
         self,
         dataset_dict,
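
The __new__ hook added in this hunk switches implementations before construction: when batching_mode="stacked" and an explicit batch_size are passed, the caller gets a StackedDataLoader back instead of a PinaDataLoader, and because the returned object is not an instance of the class, Python skips PinaDataLoader.__init__ entirely. A minimal usage sketch, where dataset_dict is a placeholder for the per-condition dataset mapping the loader expects (illustrative, not part of this change):

    # Illustrative sketch; dataset_dict is a hypothetical mapping of
    # condition names to datasets, passed positionally as args[0].
    loader = PinaDataLoader(
        dataset_dict,
        batch_size=32,
        batching_mode="stacked",   # dispatches to StackedDataLoader in __new__
        shuffle=True,
    )
    # With batch_size=None, __new__ instead prints the chosen batching mode
    # and constructs a regular PinaDataLoader.
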
@@ -193,9 +211,10 @@ class PinaDataLoader:
             # (the sum of the batch sizes is equal to
             # n_conditions * batch_size)
             batch_size_per_dataset = {
-                split: batch_size for split in dataset_dict.keys()
+                split: min(batch_size, len(ds))
+                for split, ds in dataset_dict.items()
             }
-        elif batching_mode == "propotional":
+        elif batching_mode == "proportional":
             # batch sizes is equal to the specified batch size)
             batch_size_per_dataset = self._compute_batch_size()
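
To contrast the two branches above: 'common_batch_size' gives every condition the same batch size (now capped at the length of its dataset), so one combined batch holds roughly n_conditions * batch_size samples, whereas 'proportional' splits a single batch_size across conditions via self._compute_batch_size(), whose body is not part of this diff. A rough sketch of a size-proportional split, with hypothetical names and not the actual implementation:

    # Hypothetical sketch of a size-proportional allocation; not the real
    # _compute_batch_size, which is not shown in this diff.
    def proportional_batch_sizes(dataset_dict, batch_size):
        total = sum(len(ds) for ds in dataset_dict.values())
        return {
            # Rounding means the values may not sum exactly to batch_size.
            split: max(1, round(batch_size * len(ds) / total))
            for split, ds in dataset_dict.items()
        }
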
@@ -296,30 +315,33 @@ class PinaDataLoader:
     def __iter__(self):
         """
-        Iterate over the dataloader.
+        Iterate over the dataloader. Yields a dictionary mapping split name to batch.

         :return: Yields batches from the dataloader.
         :rtype: dict
+        In 'separate_conditions' mode, batches from all conditions are yielded in random order.
         """
         if self.batching_mode == "separate_conditions":
+            tmp = []
             for split, dl in self.dataloaders.items():
-                for batch in dl:
-                    yield {split: batch}
+                len_split = len(dl)
+                for i, batch in enumerate(dl):
+                    tmp.append({split: batch})
+                    if i + 1 >= len_split:
+                        break
+            random.shuffle(tmp)
+            for batch_dict in tmp:
+                yield batch_dict
             return

         # Common_batch_size or Proportional mode (round-robin sampling)
         iterators = {
             split: itertools.cycle(dl) for split, dl in self.dataloaders.items()
         }

         # Iterate for the length of the longest dataloader
         for _ in range(len(self)):
-            batch_dict = {}
+            batch_dict: BatchDict = {}
             for split, it in iterators.items():
-                # Iterate through each dataloader and get the next batch
-                batch = next(it, None)
-                # Check if batch is None (in case of uneven lengths)
-                if batch is None:
-                    return
-                batch_dict[split] = batch
+                # Since we use itertools.cycle, next(it) will always yield a batch
+                # by repeating the dataset, so no need for the 'if batch is None: return' check.
+                batch_dict[split] = next(it)
             yield batch_dict
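
The round-robin branch guarantees that every yielded dictionary contains one batch per condition even when the per-condition dataloaders have different lengths: itertools.cycle restarts the shorter ones, and the loop runs for len(self) iterations, which the comment above ties to the longest dataloader. A self-contained toy version of the same pattern, with plain lists standing in for dataloaders (names are illustrative):

    import itertools

    # Stand-ins for per-condition dataloaders of uneven length.
    dataloaders = {"physics": [1, 2, 3], "data": ["a"]}

    iterators = {split: itertools.cycle(dl) for split, dl in dataloaders.items()}
    n_steps = max(len(dl) for dl in dataloaders.values())  # longest "dataloader"

    for _ in range(n_steps):
        print({split: next(it) for split, it in iterators.items()})
    # {'physics': 1, 'data': 'a'}
    # {'physics': 2, 'data': 'a'}
    # {'physics': 3, 'data': 'a'}
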